Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
fc7e83ee
Commit
fc7e83ee
authored
Dec 27, 2022
by
Anthony Chang
Browse files
fix compiler warnings
parent
5eec2aef
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
22 additions
and
54 deletions
+22
-54
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
...oftmax_gemm/batched_multihead_attention_backward_fp16.cpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+12
-41
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
...id/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
+8
-11
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
View file @
fc7e83ee
...
...
@@ -80,7 +80,7 @@ static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpeciali
static
constexpr
auto
TensorSpecY
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatched
GemmSoftmaxGemmPermute
_Xdl_CShuffle
<
ck
::
tensor_operation
::
device
::
DeviceBatched
MultiheadAttentionBackward
_Xdl_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
...
...
@@ -665,7 +665,7 @@ int run(int argc, char* argv[])
1e-2
);
}
return
pass
?
(
std
::
cout
<<
"pass
\n
"
,
0
)
:
(
std
::
cout
<<
"fail
\n
"
,
1
);
return
pass
?
((
void
)
(
std
::
cout
<<
"pass
\n
"
)
,
0
)
:
((
void
)
(
std
::
cout
<<
"fail
\n
"
)
,
1
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
run
(
argc
,
argv
);
}
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
fc7e83ee
...
...
@@ -68,8 +68,6 @@ __global__ void
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
LSEGridDescriptor_M
lse_grid_desc_m
,
// const QGradGridDescriptor_M_K qgrad_grid_desc_m_k, // TODO ANT: add dQ/dK args
// const KGradGridDescriptor_N_K kgrad_grid_desc_n_k,
const
VGradGridDescriptor_N_O
vgrad_grid_desc_n_o
,
const
YGradGridDesc_M0_O_M1
ygrad_grid_desc_m0_o_m1
,
const
Block2CTileMap
block_2_ctile_map
,
...
...
@@ -207,26 +205,8 @@ template <index_t NumDimG,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpecialization
MaskingSpec
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
struct
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
#if 0
: public DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
DataType,
DataType,
DataType,
DataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
MaskingSpec>
#endif
struct
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
:
public
BaseOperator
// TODO inherit atten bwd op
{
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
"Number of dimension must be greater than 0"
);
...
...
@@ -247,7 +227,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using
DeviceOp
=
DeviceBatched
GemmSoftmaxGemmPermute
_Xdl_CShuffle
;
using
DeviceOp
=
DeviceBatched
MultiheadAttentionBackward
_Xdl_CShuffle
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -363,12 +343,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
v_gs_ns_os_lengths_vec
,
v_gs_ns_os_strides_vec
)
.
second
;
// LogRangeAsType<float>(std::cout << "v_gs_os_ns_lengths_vec: ", v_gs_os_ns_lengths_vec,
// ",") << std::endl; LogRangeAsType<float>(std::cout << "v_gs_os_ns_strides_vec: ",
// v_gs_os_ns_strides_vec, ",") << std::endl; LogRangeAsType<float>(std::cout <<
// "v_gs_ns_os_lengths_vec: ", v_gs_ns_os_lengths_vec, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "v_gs_ns_os_strides_vec: ", v_gs_ns_os_strides_vec,
// ",") << std::endl;
return
PadTensorDescriptor
(
vgrad_desc_nraw_oraw
,
make_tuple
(
NPerBlock
,
Gemm1NPerBlock
),
Sequence
<
padder
.
PadN
,
padder
.
PadO
>
{});
...
...
@@ -455,7 +429,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
//
// static auto MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock()
//
// dQ = alpha * dS * K
...
...
@@ -643,7 +616,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
MaskingSpec
==
MaskingSpecialization
::
MaskOutUpperTriangle
>
;
// Argument
// FIXME: constness
struct
Argument
:
public
BaseArgument
{
Argument
(
...
...
@@ -696,10 +668,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
y_grid_desc_m_o_
{
Transform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_gs_ms_gemm1ns_strides
)},
lse_grid_desc_m_
{
DeviceOp
::
MakeLSEGridDescriptor_M
(
lse_gs_ms_lengths
[
NumDimG
])},
// dV = P^T * dY
vgrad_grid_desc_n_o_
{
DeviceOp
::
MakeVGradGridDescriptor_N_O
(
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
)},
/* PTrans descriptor will be constructed in kernel */
ygrad_grid_desc_m0_o_m1_
{
DeviceOp
::
MakeYGradGridDescriptor_M0_O_M1
(
y_grid_desc_m_o_
)},
// batch offsets
a_grid_desc_g_m_k_
{
...
...
@@ -791,9 +761,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const
DataType
*
p_c_grid_
;
const
LSEDataType
*
p_lse_grid_
;
const
DataType
*
p_ygrad_grid_
;
DataType
*
p_vgrad_grid_
;
DataType
*
p_qgrad_grid_
;
DataType
*
p_kgrad_grid_
;
DataType
*
p_vgrad_grid_
;
// tensor descriptor
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
...
...
@@ -801,6 +771,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
YGridDesc_M_O
y_grid_desc_m_o_
;
LSEGridDesc_M
lse_grid_desc_m_
;
VGradGridDesc_N_O
vgrad_grid_desc_n_o_
;
YGradGridDesc_M0_O_M1
ygrad_grid_desc_m0_o_m1_
;
// batch offsets
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
...
...
@@ -808,9 +782,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_
;
VGradGridDesc_N_O
vgrad_grid_desc_n_o_
;
YGradGridDesc_M0_O_M1
ygrad_grid_desc_m0_o_m1_
;
// block-to-c-tile map
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
...
...
@@ -927,7 +898,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
//
override
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
...
...
@@ -1010,7 +981,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
}
// polymorphic
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
//
override
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
...
...
@@ -1154,12 +1125,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
}
// polymorphic
std
::
string
GetTypeString
()
const
//
override
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceBatched
GemmSoftmaxGemmPermute
_Xdl_CShuffle"
str
<<
"DeviceBatched
MultiheadAttentionBackward
_Xdl_CShuffle"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
View file @
fc7e83ee
...
...
@@ -788,7 +788,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
make_multi_index
(
-
MPerBlock
/
Gemm2Params_N_O_M
::
B_M1
,
0
,
0
);
template
<
typename
CGradDesc_N_O
>
__host__
__device__
static
const
auto
__host__
__device__
static
auto
MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4
(
const
CGradDesc_N_O
&
c_grid_desc_n_o
)
{
// HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
...
...
@@ -811,7 +811,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
static
constexpr
auto
c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4
=
BlockwiseGemm
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
();
__host__
__device__
static
const
auto
GetCThreadOriginOnBlock_N0_O0_N1_O1_N2_O2_O3_O4
()
__host__
__device__
static
auto
GetCThreadOriginOnBlock_N0_O0_N1_O1_N2_O2_O3_O4
()
{
return
to_multi_index
(
BlockwiseGemm
::
CalculateCThreadOriginDataIndex8D
(
I0
,
I0
,
I0
,
I0
));
}
...
...
@@ -863,7 +863,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// Make all input tensors 2D and transform them into appropriate 3D form in kernel to make
// things more concise
template
<
typename
YGradGridDesc_M0_O_M1_
>
__device__
static
const
auto
__device__
static
auto
MakeYGradGridDesc_O0_M_O1
(
const
YGradGridDesc_M0_O_M1_
&
ygrad_grid_desc_m0_o_m1
)
{
const
auto
M0
=
ygrad_grid_desc_m0_o_m1
.
GetLength
(
I0
);
...
...
@@ -884,7 +884,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
template
<
typename
VGridDesc_N0_O_N1_
>
__device__
static
const
auto
__device__
static
auto
MakeVGridDesc_O0_N_O1
(
const
VGridDesc_N0_O_N1_
&
v_grid_desc_n0_o_n1
)
{
const
auto
N0
=
v_grid_desc_n0_o_n1
.
GetLength
(
I0
);
...
...
@@ -909,7 +909,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
struct
QGradGemmTile_M_K_N
{
template
<
typename
QGridDesc_K0_M_K1_
>
__device__
static
const
auto
MakeQGradGridDesc_MBlock_MPerBlock_KBlock_KPerBlock
(
__device__
static
auto
MakeQGradGridDesc_MBlock_MPerBlock_KBlock_KPerBlock
(
const
QGridDesc_K0_M_K1_
&
q_grid_desc_k0_m_k1
)
{
const
auto
K0
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
);
...
...
@@ -936,7 +936,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
template
<
typename
KGridDesc_K0_N_K1_
>
__device__
static
const
auto
__device__
static
auto
MakeKGridDesc_N0_K_N1
(
const
KGridDesc_K0_N_K1_
&
k_grid_desc_k0_n_k1
)
{
const
auto
K_K0
=
k_grid_desc_k0_n_k1
.
GetLength
(
I0
);
...
...
@@ -961,7 +961,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
{
// B position
template
<
typename
QGridDesc_K0_M_K1_
>
__device__
static
const
auto
__device__
static
auto
MakeQGridDesc_M0_K_M1
(
const
QGridDesc_K0_M_K1_
&
q_grid_desc_k0_m_k1
)
{
const
auto
Q_K0
=
q_grid_desc_k0_m_k1
.
GetLength
(
I0
);
...
...
@@ -983,7 +983,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// C position
template
<
typename
KGridDesc_K0_N_K1_
>
__device__
static
const
auto
__device__
static
auto
MakeKGradGridDesc_N_K
(
const
KGridDesc_K0_N_K1_
&
k_grid_desc_k0_n_k1
)
{
const
auto
K_K0
=
k_grid_desc_k0_n_k1
.
GetLength
(
I0
);
...
...
@@ -1668,9 +1668,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
typename
uniform_sequence_gen
<
c_thread_lengths
.
Size
(),
1
>::
type
,
false
>
;
// SnakeCurved
auto
acc0_thread_origin
=
s_blockwise_gemm
.
CalculateCThreadOriginDataIndex8D
(
Number
<
0
>
{},
Number
<
0
>
{},
Number
<
0
>
{},
Number
<
0
>
{});
constexpr
auto
block_idx_to_m_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M0
,
M1
,
M2
)),
make_unmerge_transform
(
make_tuple
(
N0
,
N1
,
N2
,
N3
,
N4
))),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment