gaoqiong / composable_kernel · Commits

Commit 8efd67d8, authored Aug 28, 2023 by letaoqin
Parent: 72539dbd

    v2 group finish
Showing 4 changed files with 182 additions and 62 deletions (+182 −62):
- example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp (+3 −3)
- example/52_flash_atten_bias/grouped_multihead_attention_bias_backward_v2.cpp (+44 −9)
- include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp (+32 −21)
- include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp (+103 −29)
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp

The batched example only tracks a signature change: the bias tensor d_g_m_n now precedes v_g_n_o in the reference helper's template parameters, signature, and call site, matching the grouped example's updated helper below.

```diff
@@ -198,8 +198,8 @@ using ReferenceDropoutInstance =
 template <typename TensorQ,
           typename TensorK,
-          typename TensorV,
-          typename TensorD,
+          typename TensorD,
+          typename TensorV,
           typename TensorS,
           typename TensorP,
           typename TensorZ,
@@ -207,8 +207,8 @@ template <typename TensorQ,
           typename TensorLSE = TensorP>
 void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             const TensorK& k_g_n_k,
-                            const TensorV& v_g_n_o,
-                            const TensorD& d_g_m_n,
+                            const TensorD& d_g_m_n,
+                            const TensorV& v_g_n_o,
                             const float alpha,
                             TensorS& s_g_m_n,
                             TensorP& p_g_m_n,
@@ -645,8 +645,8 @@ int run(int argc, char* argv[])
     // run fwd again for y, cause z_g_m_n update
     run_attention_fwd_host(q_g_m_k,
                            k_g_n_k,
-                           v_g_n_o,
-                           d_g_m_n,
+                           d_g_m_n,
+                           v_g_n_o,
                            alpha,
                            s_g_m_n,
                            p_g_m_n,
```
example/52_flash_atten_bias/grouped_multihead_attention_bias_backward_v2.cpp

This file carries the bulk of the change: the grouped example now exercises a real F16 Gemm0 bias (d0) instead of an empty bias tuple, threading it through the type aliases, host tensors, device buffers, the reference path, and the device-op argument lists.

```diff
@@ -69,8 +69,8 @@ using AccDataType = F32;
 using ShuffleDataType = F32;
 using LSEDataType = F32;
 using ZDataType = U16; // INT32
-using Acc0BiasDataType = ck::Tuple<>;
-using Acc1BiasDataType = ck::Tuple<>;
+using Acc0BiasDataType = F16;
+using Acc1BiasDataType = void;

 static constexpr ck::index_t NumDimG = 2;
 static constexpr ck::index_t NumDimM = 1;
```
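This type-alias switch is what arms the bias path: a concrete `F16` for `Acc0BiasDataType` enables the Gemm0 bias, while `void` for `Acc1BiasDataType` keeps the Gemm1 bias disabled, and the device code (see the v1 header below) branches on the type at compile time. A minimal standalone sketch of that `void`-versus-concrete dispatch, using `std::is_same` and made-up names (`BiasType`, `apply_bias`) rather than CK's internals:

```cpp
#include <iostream>
#include <type_traits>

// Illustrative stand-in for CK's bias plumbing: a concrete bias type enables
// the additive-bias branch, `void` compiles it out entirely.
template <typename BiasType>
void apply_bias(float* s, const BiasType* bias, int n)
{
    if constexpr(!std::is_same<BiasType, void>::value)
    {
        for(int i = 0; i < n; ++i)
            s[i] += static_cast<float>(bias[i]); // bias path active
    }
    else
    {
        // no bias: branch is discarded at compile time
        (void)s; (void)bias; (void)n;
    }
}

int main()
{
    float s[4] = {1, 2, 3, 4};
    float d[4] = {10, 10, 10, 10};
    apply_bias<float>(s, d, 4);      // with bias: s becomes {11, 12, 13, 14}
    apply_bias<void>(s, nullptr, 4); // without bias: no-op
    std::cout << s[0] << '\n';       // prints 11
}
```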
```diff
@@ -197,6 +197,7 @@ using ReferenceDropoutInstance =
 template <typename TensorQ,
           typename TensorK,
+          typename TensorD,
           typename TensorV,
           typename TensorS,
           typename TensorP,
@@ -205,6 +206,7 @@ template <typename TensorQ,
           typename TensorLSE = TensorP>
 void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             const TensorK& k_g_n_k,
+                            const TensorD& d_g_m_n,
                             const TensorV& v_g_n_o,
                             const float alpha,
                             TensorS& s_g_m_n,
@@ -225,6 +227,9 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     ref_gemm0_invoker.Run(ref_gemm0_argument);
+    // bias
+    s_g_m_n.ForEach(
+        [&](auto& self, auto idx) { self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx)); });
     // masking
     auto M = s_g_m_n.GetLengths()[1];
     auto N = s_g_m_n.GetLengths()[2];
```
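The new hunk in the host reference inserts the bias between the Q·Kᵀ GEMM and masking/softmax, i.e. it computes S = α·Q·Kᵀ + D elementwise. A self-contained sketch of just that step on plain arrays (sizes and values are illustrative; the real code runs over a batched Tensor via `ForEach`):

```cpp
#include <cstdio>

int main()
{
    const int M = 2, N = 3;
    // s holds alpha * Q * K^T, already produced by the preceding GEMM
    float s[M][N] = {{0.5f, -1.0f, 2.0f}, {1.5f, 0.0f, -0.5f}};
    // d is the additive attention bias, same M x N shape as s
    float d[M][N] = {{0.1f, 0.1f, 0.1f}, {-0.2f, -0.2f, -0.2f}};

    // equivalent of: s_g_m_n.ForEach([&](auto& self, auto idx)
    //                   { self(idx) += type_convert<AccDataType>(d_g_m_n(idx)); });
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            s[m][n] += d[m][n];

    std::printf("s[0][0] = %.2f\n", s[0][0]); // 0.60
}
```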
```diff
@@ -319,6 +324,7 @@ int run(int argc, char* argv[])
     using DeviceMemPtr = std::unique_ptr<DeviceMem>;

     std::vector<const void*> p_q;
     std::vector<const void*> p_k;
+    std::vector<const void*> p_d0;
     std::vector<void*> p_z;         // for result verification
     std::vector<void*> p_z_nullptr; // for time test
     std::vector<const void*> p_v;
@@ -331,6 +337,7 @@ int run(int argc, char* argv[])
     std::vector<Tensor<InputDataType>> q_g_m_ks;
     std::vector<Tensor<InputDataType>> k_g_n_ks;
+    std::vector<Tensor<Acc0BiasDataType>> d0_g_m_ns;
     std::vector<Tensor<ZDataType>> z_g_m_ns;
     std::vector<Tensor<InputDataType>> v_g_n_os;
     std::vector<Tensor<AccDataType>> s_g_m_ns;
@@ -341,6 +348,7 @@ int run(int argc, char* argv[])
     std::vector<Tensor<InputDataType>> q_tensors;
     std::vector<Tensor<InputDataType>> k_tensors;
+    std::vector<Tensor<Acc0BiasDataType>> d0_tensors;
     std::vector<Tensor<InputDataType>> v_tensors;
     std::vector<Tensor<InputDataType>> y_tensors;
     std::vector<Tensor<ZDataType>> z_tensors;
@@ -352,6 +360,7 @@ int run(int argc, char* argv[])
     std::vector<DeviceMemPtr> q_tensors_device;
     std::vector<DeviceMemPtr> k_tensors_device;
+    std::vector<DeviceMemPtr> d0_tensors_device;
     std::vector<DeviceMemPtr> z_tensors_device;
     std::vector<DeviceMemPtr> v_tensors_device;
     std::vector<DeviceMemPtr> y_tensors_device;
@@ -394,6 +403,12 @@ int run(int argc, char* argv[])
                 ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
                 : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
+        std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
+        std::vector<ck::index_t> d0_gs_ms_ns_strides =
+            input_permute
+                ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // d0 layout [G0, M, G1, N]
+                : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // d0 layout [G0, G1, M, N]
+
         std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
         std::vector<ck::index_t> z_gs_ms_ns_strides =
             input_permute
```
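The added d0 strides reuse the example's existing layout convention: a logical [G0, G1, M, N] tensor stored as [G0, M, G1, N] when `input_permute` is set, and packed [G0, G1, M, N] otherwise. A standalone sketch that evaluates both stride sets for a sample coordinate (the sizes are made up):

```cpp
#include <cassert>
#include <cstdio>
#include <vector>

// Linear offset of logical coordinate (g0, g1, m, n) under a stride set that is
// ordered per logical dimension [G0, G1, M, N], as in the example.
long offset(const std::vector<long>& s, long g0, long g1, long m, long n)
{
    return g0 * s[0] + g1 * s[1] + m * s[2] + n * s[3];
}

int main()
{
    const long G0 = 2, G1 = 3, M = 4, N = 5;
    std::vector<long> permuted{M * G1 * N, N, G1 * N, 1}; // memory layout [G0, M, G1, N]
    std::vector<long> packed{G1 * M * N, M * N, N, 1};    // memory layout [G0, G1, M, N]

    // Both layouts cover exactly G0*G1*M*N elements; the last logical element
    // lands on the last memory slot either way.
    assert(offset(permuted, G0 - 1, G1 - 1, M - 1, N - 1) == G0 * G1 * M * N - 1);
    assert(offset(packed, G0 - 1, G1 - 1, M - 1, N - 1) == G0 * G1 * M * N - 1);

    // Interior coordinates map to different slots, which is the whole point:
    std::printf("permuted offset of (1,0,2,3) = %ld\n", offset(permuted, 1, 0, 2, 3)); // 93
    std::printf("packed   offset of (1,0,2,3) = %ld\n", offset(packed, 1, 0, 2, 3));   // 73
}
```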
```diff
@@ -420,8 +435,8 @@ int run(int argc, char* argv[])
             y_gs_ms_os_strides,
             lse_gs_ms_lengths,
             lse_gs_ms_strides,
-            {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
-            {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
+            d0_gs_ms_ns_lengths,
+            d0_gs_ms_ns_strides,
             {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
             {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
         });
@@ -432,12 +447,13 @@ int run(int argc, char* argv[])
         num_byte += (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
                      sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) +
                      sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N +
-                     sizeof(OutputDataType) * N * O) *
+                     sizeof(OutputDataType) * N * O + sizeof(Acc0BiasDataType) * M * N) *
                         BatchCount +
                     sizeof(LSEDataType) * M * BatchCount;

         Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
         Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
+        Tensor<Acc0BiasDataType> d0_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
         Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
         Tensor<InputDataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
         Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
```
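The bandwidth estimate gains one term: the kernel now also reads an M×N bias per batch, so `sizeof(Acc0BiasDataType) * M * N` joins the per-batch byte count. A standalone restatement of the same formula as a function (the sizes in `main` are illustrative, not from the example):

```cpp
#include <cstddef>
#include <cstdio>

// Bytes moved per run, mirroring the example's num_byte formula: input reads
// (Q, K, V, plus two M*O tensors, presumably Y and dY), gradient writes
// (dQ, dK, dV), the new M*N bias read, and the per-batch LSE.
std::size_t bytes_moved(std::size_t M, std::size_t N, std::size_t K, std::size_t O,
                        std::size_t batch, std::size_t in_b, std::size_t out_b,
                        std::size_t bias_b, std::size_t lse_b)
{
    return (in_b * M * K + in_b * K * N +
            in_b * N * O + in_b * M * O * 2 +
            out_b * M * K + out_b * K * N +
            out_b * N * O + bias_b * M * N) * batch + // bias read is the new term
           lse_b * M * batch;
}

int main()
{
    // e.g. M=N=K=O=256, 8 batches, 2-byte F16 in/out/bias, 4-byte F32 LSE
    std::printf("%zu bytes\n", bytes_moved(256, 256, 256, 256, 8, 2, 2, 2, 4));
}
```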
```diff
@@ -447,6 +463,7 @@ int run(int argc, char* argv[])
         {
             std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
             std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
+            std::cout << "d0_gs_ms_ns: " << d0_gs_ms_ns.mDesc << std::endl;
             std::cout << "z_gs_ms_ns: " << z_gs_ms_ns.mDesc << std::endl;
             std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
             std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
@@ -461,30 +478,35 @@ int run(int argc, char* argv[])
             k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
             v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
             ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
+            d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-2, 2});
             break;
         case 2:
             q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
             k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
             v_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
             ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
+            d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<Acc0BiasDataType>{-0.5, 0.5});
             break;
         case 3:
             q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-5, 5});
             k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
             v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
             ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
+            d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
             break;
         case 4:
             q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
             k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
             v_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
             ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
+            d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
             break;
         case 5:
             q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
             k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
             v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
             ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
+            d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
             // dO dot O = [0; 1; 2; ...]
             break;
         case 6:
@@ -492,6 +514,7 @@ int run(int argc, char* argv[])
             k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
             v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
             ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
+            d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
             // assume mnko = 256
             // P = softmax(QK) = 0.0039 * ones
             // O = P V = 0.0039 * ones
@@ -506,6 +529,7 @@ int run(int argc, char* argv[])
             v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
             ygrad_gs_ms_os.GenerateTensorValue(
                 GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
+            d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
             // assume mnko = 256
             // P = softmax(QK) = 0.0039 * ones
             // O = P V = 0.0039 * ones
@@ -517,6 +541,7 @@ int run(int argc, char* argv[])
         }

         Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
         Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
+        Tensor<Acc0BiasDataType> d0_g_m_n({BatchCount, M, N});
         Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
         Tensor<InputDataType> v_g_n_o({BatchCount, N, O});
         Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
@@ -531,12 +556,16 @@ int run(int argc, char* argv[])
         k_gs_ns_ks.ForEach([&](auto& self, auto idx) {
             k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
         });
+        d0_gs_ms_ns.ForEach([&](auto& self, auto idx) {
+            d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+        });
         v_gs_os_ns.ForEach([&](auto& self, auto idx) {
             v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
         });

         q_g_m_ks.push_back(q_g_m_k);
         k_g_n_ks.push_back(k_g_n_k);
+        d0_g_m_ns.push_back(d0_g_m_n);
         z_g_m_ns.push_back(z_g_m_n);
         v_g_n_os.push_back(v_g_n_o);
         s_g_m_ns.push_back(s_g_m_n);
```
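The d0 copy added here follows the same reshaping the example already applies to Q, K and V: the [G0, G1] group dimensions collapse into one batch index `g = idx[0] * G1 + idx[1]` so the reference implementation can work on [G, M, N] tensors. A standalone sketch of that flattening (shapes are illustrative):

```cpp
#include <cstdio>

int main()
{
    const int G0 = 2, G1 = 3, M = 2, N = 2;
    float d0_gs_ms_ns[G0][G1][M][N]; // 4-D host tensor, as in the example
    float d0_g_m_n[G0 * G1][M][N];   // flattened 3-D tensor fed to the reference

    // fill with a recognizable pattern
    for(int g0 = 0; g0 < G0; ++g0)
        for(int g1 = 0; g1 < G1; ++g1)
            for(int m = 0; m < M; ++m)
                for(int n = 0; n < N; ++n)
                    d0_gs_ms_ns[g0][g1][m][n] = 100.f * g0 + 10.f * g1 + m + 0.1f * n;

    // equivalent of: d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
    for(int g0 = 0; g0 < G0; ++g0)
        for(int g1 = 0; g1 < G1; ++g1)
            for(int m = 0; m < M; ++m)
                for(int n = 0; n < N; ++n)
                    d0_g_m_n[g0 * G1 + g1][m][n] = d0_gs_ms_ns[g0][g1][m][n];

    std::printf("d0_g_m_n[4][1][1] = %.1f\n", d0_g_m_n[4][1][1]); // g0=1, g1=1 -> 111.1
}
```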
```diff
@@ -546,6 +575,7 @@ int run(int argc, char* argv[])
         p_drop_g_m_ns.push_back(p_drop_g_m_n);

         q_tensors.push_back(q_gs_ms_ks);
         k_tensors.push_back(k_gs_ns_ks);
+        d0_tensors.push_back(d0_gs_ms_ns);
         v_tensors.push_back(v_gs_os_ns);
         y_tensors.push_back(y_gs_ms_os);
         z_tensors.push_back(z_gs_ms_ns);
@@ -555,6 +585,8 @@ int run(int argc, char* argv[])
             std::make_unique<DeviceMem>(sizeof(InputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
         k_tensors_device.emplace_back(
             std::make_unique<DeviceMem>(sizeof(InputDataType) * k_gs_ns_ks.GetElementSpaceSize()));
+        d0_tensors_device.emplace_back(
+            std::make_unique<DeviceMem>(sizeof(Acc0BiasDataType) * d0_gs_ms_ns.GetElementSpaceSize()));
         z_tensors_device.emplace_back(
             std::make_unique<DeviceMem>(sizeof(ZDataType) * z_gs_ms_ns.GetElementSpaceSize()));
         v_tensors_device.emplace_back(
@@ -573,11 +605,13 @@ int run(int argc, char* argv[])
             std::make_unique<DeviceMem>(sizeof(InputDataType) * y_gs_ms_os.GetElementSpaceSize()));

         q_tensors_device.back()->ToDevice(q_gs_ms_ks.data());
         k_tensors_device.back()->ToDevice(k_gs_ns_ks.data());
+        d0_tensors_device.back()->ToDevice(d0_gs_ms_ns.data());
         z_tensors_device.back()->ToDevice(z_gs_ms_ns.data());
         v_tensors_device.back()->ToDevice(v_gs_os_ns.data());
         ygrad_tensors_device.back()->ToDevice(ygrad_gs_ms_os.data());

         p_q.push_back(q_tensors_device.back()->GetDeviceBuffer());
         p_k.push_back(k_tensors_device.back()->GetDeviceBuffer());
+        p_d0.push_back(d0_tensors_device.back()->GetDeviceBuffer());
         p_z.push_back(z_tensors_device.back()->GetDeviceBuffer());
         p_z_nullptr.push_back(nullptr);
         p_v.push_back(v_tensors_device.back()->GetDeviceBuffer());
@@ -599,8 +633,8 @@ int run(int argc, char* argv[])
         p_qgrad,
         p_kgrad,
         p_vgrad,
-        {}, // std::array<void*, 1> p_acc0_biases;
-        {}, // std::array<void*, 1> p_acc1_biases;
+        p_d0,
+        {},
         problem_descs,
         QKVElementOp{},
         QKVElementOp{},
@@ -645,8 +679,8 @@ int run(int argc, char* argv[])
         p_qgrad,
         p_kgrad,
         p_vgrad,
-        {}, // std::array<void*, 1> p_acc0_biases;
-        {}, // std::array<void*, 1> p_acc1_biases;
+        p_d0,
+        {},
         problem_descs,
         QKVElementOp{},
         QKVElementOp{},
@@ -675,6 +709,7 @@ int run(int argc, char* argv[])
         });

         run_attention_fwd_host(q_g_m_ks[i],
                                k_g_n_ks[i],
+                               d0_g_m_ns[i],
                                v_g_n_os[i],
                                alpha,
                                s_g_m_ns[i],
```
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp

The v1 device op gains a D0DataType kernel template parameter and a per-batch D0 grid pointer resolved inside the kernel, and its bias pointer arguments change from fixed-size std::array to std::vector.

```diff
@@ -26,6 +26,7 @@ namespace tensor_operation {
 namespace device {

 template <typename GridwiseGemm,
+          typename D0DataType,
           typename GroupKernelArg,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
@@ -100,6 +101,15 @@ __global__ void
             (arg_ptr[group_id].p_z_grid_ == nullptr
                  ? nullptr
                  : arg_ptr[group_id].p_z_grid_ + z_batch_offset);

+        const D0DataType* tmp_p_d0_grid = nullptr;
+        if constexpr(!is_same<D0DataType, void>::value)
+        {
+            const long_index_t d0_batch_offset =
+                __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
+                    arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
+            tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
+        }
+
         if constexpr(Deterministic)
         {
             for(index_t i = 0; i < num_blocks_per_batch; i++)
```
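The kernel resolves the per-batch bias pointer once, before the main loop, and only when `D0DataType` is not `void`; `__builtin_amdgcn_readfirstlane` keeps the computed offset uniform across the wavefront. A host-compilable sketch of the same compile-time wiring, with the AMD intrinsic omitted and hypothetical `Arg`/`run_tile` names standing in for the kernel argument struct and body:

```cpp
#include <cstdio>
#include <type_traits>

template <typename D0DataType>
struct Arg // hypothetical stand-in for the per-group kernel argument
{
    const D0DataType* p_d0_grid_ = nullptr;
    long d0_batch_stride_ = 0;
};

template <typename D0DataType>
void run_tile(const Arg<D0DataType>& arg, long g_idx)
{
    // Resolve the bias pointer once per batch; stays nullptr when bias is disabled.
    const D0DataType* tmp_p_d0_grid = nullptr;
    if constexpr(!std::is_same<D0DataType, void>::value)
    {
        const long d0_batch_offset = g_idx * arg.d0_batch_stride_;
        tmp_p_d0_grid = arg.p_d0_grid_ + d0_batch_offset;
    }
    std::printf("d0 pointer for batch %ld: %p\n", g_idx, (void*)tmp_p_d0_grid);
}

int main()
{
    float d0[8] = {};
    run_tile<float>({d0, 4}, 1); // bias enabled: pointer offset by one batch
    run_tile<void>({}, 1);       // bias disabled: branch compiled out, stays nullptr
}
```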
```diff
@@ -107,6 +117,7 @@ __global__ void
                 GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
                     arg_ptr[group_id].p_a_grid_ + a_batch_offset,
                     arg_ptr[group_id].p_b_grid_ + b_batch_offset,
+                    tmp_p_d0_grid,
                     z_matrix_ptr,
                     arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
                     arg_ptr[group_id].p_c_grid_ + c_batch_offset,
@@ -143,6 +154,7 @@ __global__ void
             GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
                 arg_ptr[group_id].p_a_grid_ + a_batch_offset,
                 arg_ptr[group_id].p_b_grid_ + b_batch_offset,
+                tmp_p_d0_grid,
                 z_matrix_ptr,
                 arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
                 arg_ptr[group_id].p_c_grid_ + c_batch_offset,
@@ -258,11 +270,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
                   "Number of dimension must be greater than 0");

-    static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
-    static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
+    using D0DataType = Acc0BiasDataType;
+    using D1DataType = Acc1BiasDataType;

     // TODO: implement bias combination
-    static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
+    static_assert(is_same<D1DataType, void>::value, "Bias 1 addition is unimplemented");

     using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1;

     struct ProblemDesc
@@ -482,6 +494,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
             return lse_grid_desc_mraw;
         }
     }

+    // D in Gemm0 C position
+    static auto MakeDGridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths_vec,
+                                        const std::vector<index_t>& d_gs_ms_ns_strides_vec)
+    {
+        return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths_vec, d_gs_ms_ns_strides_vec);
+    }
+
     using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
     using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
@@ -495,6 +513,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
     using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
     using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
+    using D0GridDesc_M_N = decltype(MakeDGridDescriptor_M_N({}, {}));
     using YGradGridDesc_O0_M_O1 = decltype(MakeYGradGridDescriptor_O0_M_O1({}, {}));
     using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
@@ -574,6 +593,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
     // GridwiseGemm
     using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1<
         InputDataType, // TODO: distinguish A/B datatype
+        D0DataType,
         OutputDataType,
         ZDataType,
         GemmDataType,
@@ -589,6 +609,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
         AGridDesc_AK0_M_AK1,
         BGridDesc_BK0_N_BK1,
         KGridDesc_N_K,
+        D0GridDesc_M_N,
         ZGridDesc_M_N,
         B1GridDesc_BK0_N_BK1,
         YGridDesc_M_O,
@@ -625,6 +646,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
         BBlockTransferDstScalarPerVector_BK1,
         true,
         BBlockLdsExtraN,
+        D0BlockTransferSrcScalarPerVector,
         CShuffleMXdlPerWavePerShuffle,
         CShuffleNXdlPerWavePerShuffle,
         CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -706,8 +728,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                  std::vector<void*>& p_Qgrads,
                  std::vector<void*>& p_Kgrads,
                  std::vector<void*>& p_Vgrads,
-                 const std::array<void*, NumAcc0Bias>& p_acc0_biases,
-                 const std::array<void*, NumAcc1Bias>& p_acc1_biases,
+                 const std::vector<const void*>& p_acc0_biases,
+                 const std::vector<const void*>& p_acc1_biases,
                  const std::vector<ProblemDesc>& problem_desc_vec,
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
```
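Switching `p_acc0_biases`/`p_acc1_biases` from `std::array<void*, N>` to `std::vector<const void*>` reflects that the bias count is no longer a compile-time property (`NumAcc0Bias` disappears, along with the size check deleted below); callers pass one const device pointer per group, or an empty vector for no bias. A sketch of that calling convention (`MakeArgument` here is a simplified stand-in, not the real CK signature):

```cpp
#include <cstdio>
#include <vector>

// Simplified stand-in for the device op's argument factory: the bias lists are
// runtime-sized vectors of const pointers, one entry per problem/group.
void MakeArgument(const std::vector<const void*>& p_acc0_biases,
                  const std::vector<const void*>& p_acc1_biases)
{
    std::printf("gemm0 biases: %zu, gemm1 biases: %zu\n",
                p_acc0_biases.size(), p_acc1_biases.size());
}

int main()
{
    std::vector<const void*> p_d0; // device buffer pointers, one per group
    float group0_bias[16] = {}, group1_bias[16] = {};
    p_d0.push_back(group0_bias);
    p_d0.push_back(group1_bias);

    MakeArgument(p_d0, {}); // with Gemm0 bias, no Gemm1 bias
    MakeArgument({}, {});   // no biases at all
}
```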
```diff
@@ -836,18 +858,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                 grid_size_ += grid_size_grp;

-                // for each group, make sure acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias and
-                // so on
-                if(!(problem_desc.acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias &&
-                     problem_desc.acc0_biases_gs_ms_ns_strides.size() == NumAcc0Bias &&
-                     problem_desc.acc1_biases_gs_ms_os_lengths.size() == NumAcc1Bias &&
-                     problem_desc.acc1_biases_gs_ms_os_strides.size() == NumAcc1Bias))
-                {
-                    throw std::runtime_error(
-                        "wrong! number of biases in function argument does not "
-                        "match that in template argument");
-                }
-
                 const auto raw_m_padded = GridwiseGemm::GetPaddedSize(
                     problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1]);
                 const auto raw_n_padded = GridwiseGemm::GetPaddedSize(
@@ -964,6 +974,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
             const auto kernel =
                 kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v1<
                     GridwiseGemm,
+                    D0DataType,
                     GroupKernelArg,
                     AElementwiseOperation,
                     BElementwiseOperation,
@@ -1128,8 +1139,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                  std::vector<void*>& p_Qgrads,
                  std::vector<void*>& p_Kgrads,
                  std::vector<void*>& p_Vgrads,
-                 const std::array<void*, NumAcc0Bias>& p_acc0_biases,
-                 const std::array<void*, NumAcc1Bias>& p_acc1_biases,
+                 const std::vector<const void*>& p_acc0_biases,
+                 const std::vector<const void*>& p_acc1_biases,
                  const std::vector<ProblemDesc>& problem_desc_vec,
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
@@ -1176,8 +1187,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
                  std::vector<void*>& p_Qgrads,
                  std::vector<void*>& p_Kgrads,
                  std::vector<void*>& p_Vgrads,
-                 const std::array<void*, NumAcc0Bias>& p_acc0_biases,
-                 const std::array<void*, NumAcc1Bias>& p_acc1_biases,
+                 const std::vector<const void*>& p_acc0_biases,
+                 const std::vector<const void*>& p_acc1_biases,
                  const std::vector<ProblemDesc>& problem_desc_vec,
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
```
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp

This diff is collapsed in the original page view (+103 −29); its contents are not reproduced here.