Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
496e2ec6
"vscode:/vscode.git/clone" did not exist on "353b228d09c54f6e385045c82f53befb38ad8738"
Commit
496e2ec6
authored
Nov 19, 2021
by
Chao Liu
Browse files
move C pointwise operation into threadwise copy
parent
f0201ead
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
48 additions
and
60 deletions
+48
-60
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
+15
-13
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp
...ude/tensor_operation/threadwise_tensor_slice_transfer.hpp
+8
-3
device_operation/include/device_base.hpp
device_operation/include/device_base.hpp
+0
-9
device_operation/include/device_gemm.hpp
device_operation/include/device_gemm.hpp
+4
-2
device_operation/include/device_gemm_xdl.hpp
device_operation/include/device_gemm_xdl.hpp
+19
-31
example/1_gemm_xdl/gemm_xdl.cpp
example/1_gemm_xdl/gemm_xdl.cpp
+2
-2
No files found.
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
View file @
496e2ec6
...
...
@@ -34,7 +34,7 @@ __global__ void
const
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1
,
const
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
Block2CTileMap
block_2_ctile_map
,
const
CElementwiseOperation
c_element
wise
_op
)
const
CElementwiseOperation
c_element_op
)
{
constexpr
index_t
shared_block_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatAB
);
...
...
@@ -49,7 +49,7 @@ __global__ void
b_grid_desc_k0_n_k1
,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
block_2_ctile_map
,
c_element
wise
_op
);
c_element_op
);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
template
<
typename
GridwiseGemm
,
...
...
@@ -58,7 +58,8 @@ template <typename GridwiseGemm,
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
typename
Block2CTileMap
>
typename
Block2CTileMap
,
typename
CElementwiseOperation
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
...
...
@@ -69,7 +70,8 @@ __global__ void
const
void
CONSTANT
*
p_a_grid_desc_k0_m_k1
,
const
void
CONSTANT
*
p_b_grid_desc_k0_n_k1
,
const
void
CONSTANT
*
p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
void
CONSTANT
*
p_block_2_ctile_map
)
const
void
CONSTANT
*
p_block_2_ctile_map
,
const
void
CONSTANT
*
p_c_element_op
)
{
constexpr
index_t
shared_block_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatAB
);
...
...
@@ -83,6 +85,8 @@ __global__ void
cast_pointer_to_generic_address_space
(
p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
));
const
auto
block_2_ctile_map
=
*
reinterpret_cast
<
const
Block2CTileMap
*>
(
cast_pointer_to_generic_address_space
(
p_block_2_ctile_map
));
const
auto
c_element_op
=
*
reinterpret_cast
<
const
CElementwiseOperation
*>
(
cast_pointer_to_generic_address_space
(
p_c_element_op
));
__shared__
FloatAB
p_shared_block
[
shared_block_size
];
...
...
@@ -93,7 +97,8 @@ __global__ void
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
block_2_ctile_map
);
block_2_ctile_map
,
c_element_op
);
}
#endif
...
...
@@ -105,7 +110,7 @@ template <index_t BlockSize,
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDesc_M_N
,
typename
CElementwiseOp
,
typename
CElementwiseOp
eration
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
K0PerBlock
,
...
...
@@ -358,7 +363,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
&
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
Block2CTileMap
&
block_2_ctile_map
,
const
CElementwiseOp
&
c_element
wise
_op
)
const
CElementwiseOp
eration
&
c_element_op
)
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_a_grid
,
a_grid_desc_k0_m_k1
.
GetElementSpaceSize
());
...
...
@@ -578,10 +583,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
M0
>
{},
Number
<
N0
>
{},
I1
,
I1
,
Number
<
M2
>
{},
I1
,
Number
<
M4
>
{},
I1
));
// elementwise Op to C
static_for
<
0
,
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetElementSpaceSize
(),
1
>
{}(
[
&
](
auto
i
)
{
c_thread_buf
(
i
)
=
c_elementwise_op
(
c_thread_buf
[
i
]);
});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const
auto
c_thread_mtx_on_block
=
...
...
@@ -619,6 +620,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
FloatC
,
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
CElementwiseOperation
,
Sequence
<
M0
,
N0
,
I1
,
I1
,
M2
,
I1
,
M4
,
I1
>
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
...
...
@@ -626,7 +628,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
CGlobalMemoryDataOperation
,
1
,
true
>
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
make_multi_index
(
m_thread_data_on_grid_idx
[
I0
],
n_thread_data_on_grid_idx
[
I0
],
...
...
@@ -635,7 +636,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
m_thread_data_on_grid_idx
[
I2
],
m_thread_data_on_grid_idx
[
I3
],
m_thread_data_on_grid_idx
[
I4
],
n_thread_data_on_grid_idx
[
I2
])};
n_thread_data_on_grid_idx
[
I2
]),
c_element_op
};
c_thread_copy
.
Run
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
...
...
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp
View file @
496e2ec6
...
...
@@ -50,6 +50,7 @@ template <typename SrcData,
typename
DstData
,
typename
SrcDesc
,
typename
DstDesc
,
typename
ElementwiseOp
,
typename
SliceLengths
,
typename
DimAccessOrder
,
index_t
DstVectorDim
,
...
...
@@ -69,8 +70,10 @@ struct ThreadwiseTensorSliceTransfer_v1r3
using
DstCoordStep
=
decltype
(
make_tensor_coordinate_step
(
DstDesc
{},
Index
{}));
__device__
constexpr
ThreadwiseTensorSliceTransfer_v1r3
(
const
DstDesc
&
dst_desc
,
const
Index
&
dst_slice_origin_idx
)
:
dst_coord_
(
make_tensor_coordinate
(
dst_desc
,
dst_slice_origin_idx
))
const
Index
&
dst_slice_origin_idx
,
const
ElementwiseOp
element_op
)
:
dst_coord_
(
make_tensor_coordinate
(
dst_desc
,
dst_slice_origin_idx
)),
element_op_
{
element_op
}
{
static_assert
(
SrcDesc
::
IsKnownAtCompileTime
(),
"wrong! SrcDesc need to known at compile-time"
);
...
...
@@ -195,8 +198,9 @@ struct ThreadwiseTensorSliceTransfer_v1r3
constexpr
index_t
src_offset
=
src_desc
.
CalculateOffset
(
src_slice_origin_idx
+
dst_data_idx
+
i
*
dst_scalar_step_in_vector
);
// apply element-wise operation and type convert
dst_vector
.
template
AsType
<
DstData
>()(
i
)
=
type_convert
<
DstData
>
(
src_buf
[
Number
<
src_offset
>
{}]);
type_convert
<
DstData
>
(
element_op_
(
src_buf
[
Number
<
src_offset
>
{}])
)
;
});
const
bool
is_dst_valid
=
...
...
@@ -373,6 +377,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
private:
DstCoord
dst_coord_
;
ElementwiseOp
element_op_
;
};
// namespace ck
// Assume:
...
...
device_operation/include/device_base.hpp
View file @
496e2ec6
...
...
@@ -36,15 +36,6 @@ struct BaseOperator
virtual
~
BaseOperator
()
{}
};
struct
BaseGpuOperator
{
BaseGpuOperator
()
=
default
;
BaseGpuOperator
(
const
BaseGpuOperator
&
)
=
default
;
BaseGpuOperator
&
operator
=
(
const
BaseGpuOperator
&
)
=
default
;
virtual
~
BaseGpuOperator
()
{}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
...
...
device_operation/include/device_gemm.hpp
View file @
496e2ec6
...
...
@@ -8,6 +8,7 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
template
<
typename
CElementwiseOperation
>
struct
DeviceGemm
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
...
...
@@ -20,12 +21,13 @@ struct DeviceGemm : public BaseOperator
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
std
::
unique_ptr
<
BaseGpu
Operat
or
>
c_element_op
_ptr
)
=
0
;
CElementwise
Operat
ion
c_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
using
DeviceGemmPtr
=
std
::
unique_ptr
<
DeviceGemm
>
;
template
<
typename
CElementwiseOperation
>
using
DeviceGemmPtr
=
std
::
unique_ptr
<
DeviceGemm
<
CElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
...
...
device_operation/include/device_gemm_xdl.hpp
View file @
496e2ec6
...
...
@@ -50,7 +50,7 @@ template <typename ADataType,
ck
::
index_t
CThreadTransferDstScalarPerVector
,
bool
ABlockLdsAddExtraM
,
bool
BBlockLdsAddExtraN
>
struct
DeviceGemmXdl
:
public
DeviceGemm
struct
DeviceGemmXdl
:
public
DeviceGemm
<
CElementwiseOperation
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -233,7 +233,7 @@ struct DeviceGemmXdl : public DeviceGemm
index_t
StrideC
,
index_t
M01
,
index_t
N01
,
CElementwiseOperation
c_element
wise
_op
)
CElementwiseOperation
c_element_op
)
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
...
...
@@ -244,7 +244,7 @@ struct DeviceGemmXdl : public DeviceGemm
block_2_ctile_map_
{},
M01_
{
M01
},
N01_
{
N01
},
c_element
wise
_op_
{
c_element
wise
_op
}
c_element_op_
{
c_element_op
}
{
a_grid_desc_k0_m_k1_
=
DeviceGemmXdl
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
b_grid_desc_k0_n_k1_
=
DeviceGemmXdl
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
);
...
...
@@ -271,7 +271,7 @@ struct DeviceGemmXdl : public DeviceGemm
Block2CTileMap
block_2_ctile_map_
;
index_t
M01_
;
index_t
N01_
;
CElementwiseOperation
c_element
wise
_op_
;
CElementwiseOperation
c_element_op_
;
};
// Invoker
...
...
@@ -337,7 +337,7 @@ struct DeviceGemmXdl : public DeviceGemm
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
block_2_ctile_map_
,
arg
.
c_element
wise
_op_
);
arg
.
c_element_op_
);
}
else
{
...
...
@@ -364,7 +364,7 @@ struct DeviceGemmXdl : public DeviceGemm
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
block_2_ctile_map_
,
arg
.
c_element
wise
_op_
);
arg
.
c_element_op_
);
}
return
ave_time
;
...
...
@@ -407,27 +407,15 @@ struct DeviceGemmXdl : public DeviceGemm
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
std
::
unique_ptr
<
BaseGpuOperator
>
c_op_ptr
)
CElementwiseOperation
c_element_op
)
{
return
Argument
{
p_a
,
p_b
,
p_c
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
1
,
1
,
*
dynamic_cast
<
CElementwiseOperation
*>
(
c_op_ptr
.
get
())};
return
Argument
{
p_a
,
p_b
,
p_c
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
1
,
1
,
c_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
index_t
M
,
...
...
@@ -436,7 +424,7 @@ struct DeviceGemmXdl : public DeviceGemm
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
std
::
unique_ptr
<
BaseGpuOperator
>
c_op_ptr
)
override
CElementwiseOperation
c_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
...
...
@@ -449,7 +437,7 @@ struct DeviceGemmXdl : public DeviceGemm
StrideC
,
1
,
1
,
*
dynamic_cast
<
CElementwiseOperation
*>
(
c_op_ptr
.
get
())
);
c_element_op
);
}
// polymorphic
...
...
example/1_gemm_xdl/gemm_xdl.cpp
View file @
496e2ec6
...
...
@@ -14,7 +14,7 @@
#include "device_base.hpp"
#include "device_gemm_xdl.hpp"
struct
Activation
:
public
ck
::
tensor_operation
::
device
::
BaseGpuOperator
struct
Activation
{
float
alpha
=
0.1
;
...
...
@@ -191,7 +191,7 @@ int main(int argc, char* argv[])
StrideA
,
StrideB
,
StrideC
,
std
::
make_unique
<
Activation
>
(
activation
)
)
;
activation
);
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment