gaoqiong / composable_kernel / Commits / b67a58c0
"vscode:/vscode.git/clone" did not exist on "e8ab4ba33785c348235c1ef932c7671e7e64687d"
Commit b67a58c0 authored Nov 29, 2022 by Anthony Chang

    can validate dV with relaxed error tolerance

parent 8551dd43
Showing 10 changed files with 513 additions and 264 deletions (+513 -264)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp   +52  -57
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp                            +18  -2
include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp                                +18  -0
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp  +88  -10
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp  +326  -191
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp                +1   -0
include/ck/utility/thread_group.hpp                                                        +2   -2
library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp            +7   -0
library/include/ck/library/utility/check_err.hpp                                           +1   -1
test/space_filling_curve/CMakeLists.txt                                                    +0   -1
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp  View file @ b67a58c0
...
@@ -15,7 +15,9 @@ Outputs:
 */
-#define PRINT_HOST 1
+#pragma clang diagnostic ignored "-Wunused-variable"
+#define PRINT_HOST 0
 #include <iostream>
 #include <numeric>
...
@@ -50,6 +52,7 @@ using YElementOp = PassThrough;
 using DataType         = F16;
 using AccDataType      = F32;
 using ShuffleDataType  = F32;
+using LSEDataType      = F32;
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;
...
@@ -76,6 +79,7 @@ using DeviceGemmInstance =
     NumDimK,
     NumDimO,
     DataType,
+    LSEDataType,
     Acc0BiasDataType,
     Acc1BiasDataType,
     AccDataType,
...
@@ -172,14 +176,16 @@ template <typename TensorQ,
           typename TensorV,
           typename TensorS,
           typename TensorP,
-          typename TensorY>
+          typename TensorY,
+          typename TensorLSE = TensorP>
 void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             const TensorK& k_g_n_k,
                             const TensorV& v_g_n_o,
                             const float alpha,
                             TensorS& s_g_m_n,
                             TensorP& p_g_m_n,
-                            TensorY& y_g_m_o)
+                            TensorY& y_g_m_o,
+                            TensorLSE& lse_g_m)
 {
     // S = alpha * Q * K^T
     auto k_g_k_n = k_g_n_k.Transpose({0, 2, 1});
...
@@ -207,7 +213,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     // [0.1748777 , 0.1748777 , 0.1748777 , 0.47536689]])
     auto ref_softmax          = ReferenceSoftmaxInstance{};
     auto ref_softmax_invoker  = ref_softmax.MakeInvoker();
-    auto ref_softmax_argument = ref_softmax.MakeArgument(s_g_m_n, p_g_m_n, 1, 0, {2});
+    auto ref_softmax_argument = ref_softmax.MakeArgument(s_g_m_n, p_g_m_n, 1, 0, {2}, &lse_g_m);
     ref_softmax_invoker.Run(ref_softmax_argument);
...
@@ -230,10 +236,10 @@ int run(int argc, char* argv[])
     // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
     // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
     // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
-    ck::index_t M  = 4;
-    ck::index_t N  = 4;
-    ck::index_t K  = 4;
-    ck::index_t O  = 4;
+    ck::index_t M  = 256;
+    ck::index_t N  = 256;
+    ck::index_t K  = 256;
+    ck::index_t O  = 256;
     ck::index_t G0 = 1;
     ck::index_t G1 = 1;
...
@@ -242,8 +248,6 @@ int run(int argc, char* argv[])
     bool input_permute  = false;
     bool output_permute = false;

-    const ck::index_t BatchCount = G0 * G1;
-
     if(argc == 1)
     {
         // use default case
...
@@ -283,6 +287,8 @@ int run(int argc, char* argv[])
         exit(0);
     }

+    const ck::index_t BatchCount = G0 * G1;
+
     std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
     std::vector<ck::index_t> q_gs_ms_ks_strides =
         input_permute
...
@@ -307,16 +313,27 @@ int run(int argc, char* argv[])
             ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
             : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]

+    // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
+    // Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
+    //    = exp(Si) / exp(log(sum(exp() + ...)))
+    //    = exp(Si - log(sum(exp() + ...)))
+    //               ^^^^^^^^^^^^^^^^^^^^^
+    //                        LSE
+    std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
+    std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
+
     Tensor<DataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
     Tensor<DataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
     Tensor<DataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
     Tensor<DataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
     Tensor<DataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
+    Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);

     std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
     std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
     std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
     std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
+    std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;

     switch(init_method)
     {
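[Editor's aside, not part of the commit: the LSE identity in the comment above is easy to verify in standalone C++. The helper name softmax_via_lse is hypothetical; it only illustrates why storing one scalar per row lets the backward pass recompute P without re-reducing.]

    // Minimal sketch: recomputing softmax probabilities from a stored log-sum-exp.
    #include <algorithm>
    #include <cmath>
    #include <vector>

    // For one row S[0..n): LSE = max + log(sum(exp(S_i - max))), then P_i = exp(S_i - LSE).
    std::vector<float> softmax_via_lse(const std::vector<float>& s)
    {
        const float m = *std::max_element(s.begin(), s.end());
        float sum     = 0.f;
        for(float x : s)
            sum += std::exp(x - m);
        const float lse = m + std::log(sum); // the value the forward pass saves per row
        std::vector<float> p(s.size());
        for(std::size_t i = 0; i < s.size(); ++i)
            p[i] = std::exp(s[i] - lse); // backward pass: no per-row reduction needed
        return p;
    }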
...
...
@@ -340,19 +357,20 @@ int run(int argc, char* argv[])
         ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
         break;
     default:
-        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
+        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{10});
+        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, n] = m
     }

-    // calculate y beforehand
+    // calculate y & log-sum-exp beforehand
     Tensor<DataType> q_g_m_k({BatchCount, M, K});
     Tensor<DataType> k_g_n_k({BatchCount, N, K});
     Tensor<DataType> v_g_n_o({BatchCount, N, O});
     Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
     Tensor<DataType> p_g_m_n({BatchCount, M, N});
     Tensor<DataType> y_g_m_o({BatchCount, M, O});
+    Tensor<LSEDataType> lse_g_m({BatchCount, M});
     q_gs_ms_ks.ForEach(
         [&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
...
...
@@ -360,27 +378,36 @@ int run(int argc, char* argv[])
         [&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
     v_gs_os_ns.ForEach(
         [&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
+    lse_gs_ms.ForEach(
+        [&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });

-    run_attention_fwd_host(q_g_m_k, k_g_n_k, v_g_n_o, alpha, s_g_m_n, p_g_m_n, y_g_m_o);
+    run_attention_fwd_host(q_g_m_k, k_g_n_k, v_g_n_o, alpha, s_g_m_n, p_g_m_n, y_g_m_o, lse_g_m);

     y_gs_ms_os.ForEach([&](auto& self, auto idx) {
         self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
     });
+    lse_gs_ms.ForEach(
+        [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });

     // qkv gradients have the same descriptor as with qkv
     DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
     DeviceMem k_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
     DeviceMem v_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
     DeviceMem y_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
+    DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize());
     DeviceMem qgrad_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
     DeviceMem kgrad_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
     DeviceMem vgrad_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
     DeviceMem ygrad_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());

     // TODO ANT: make sure K/V gradients are zeroed
     q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
     k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
     v_device_buf.ToDevice(v_gs_os_ns.mData.data());
     y_device_buf.ToDevice(y_gs_ms_os.mData.data());
-    ygrad_device_buf.ToDevice(y_gs_ms_os.mData.data());
+    lse_device_buf.ToDevice(lse_gs_ms.mData.data());
+    ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
+    kgrad_device_buf.SetZero();
+    vgrad_device_buf.SetZero();

     // TODO ANT: attention backward kernel
     auto gemm     = DeviceGemmInstance{};
     auto invoker  = gemm.MakeInvoker();
     auto argument = gemm.MakeArgument(
...
...
@@ -388,6 +415,7 @@ int run(int argc, char* argv[])
         static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
         static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
         static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
+        static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
         static_cast<DataType*>(ygrad_device_buf.GetDeviceBuffer()),
         static_cast<DataType*>(qgrad_device_buf.GetDeviceBuffer()),
         static_cast<DataType*>(kgrad_device_buf.GetDeviceBuffer()),
...
...
@@ -402,6 +430,7 @@ int run(int argc, char* argv[])
         v_gs_os_ns_strides,
         y_gs_ms_os_lengths,
         y_gs_ms_os_strides,
+        lse_gs_ms_lengths,
         {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
         {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
         {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
...
...
@@ -421,6 +450,7 @@ int run(int argc, char* argv[])
     float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

+    // TODO ANT: add dQ/dK/dV flops & bytes
     std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
     std::size_t num_btype =
         (sizeof(DataType) * M * K + sizeof(DataType) * K * N + sizeof(DataType) * N * O +
          sizeof(DataType) * M * O) *
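[Editor's aside, not part of the commit: the flop count above covers only the two forward GEMMs, one multiply and one add per MAC, per batch:]

    \text{FLOPs} = \underbrace{2MNK}_{S = \alpha\, Q K^{\top}} \;+\; \underbrace{2MNO}_{Y = P V} \quad \text{per batch}

The TODO above acknowledges that dQ/dK/dV flops and bytes are not yet included.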
...
...
@@ -445,7 +475,7 @@ int run(int argc, char* argv[])
     Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M});
     ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
-        ygrad_g_m_o(idx[0] * G1 * idx[1], idx[3], idx[2]) = self(idx);
+        ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
     });

     if(PRINT_HOST)
...
...
@@ -456,44 +486,6 @@ int run(int argc, char* argv[])
         std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
     }

-    // S = alpha * Q * K^T
-    auto k_g_k_n            = k_g_n_k.Transpose({0, 2, 1});
-    auto ref_gemm0          = ReferenceGemm0Instance{};
-    auto ref_gemm0_invoker  = ref_gemm0.MakeInvoker();
-    auto ref_gemm0_argument = ref_gemm0.MakeArgument(
-        q_g_m_k, k_g_k_n, s_g_m_n, PassThrough{}, PassThrough{}, Scale{alpha});
-    ref_gemm0_invoker.Run(ref_gemm0_argument);
-
-    // masking
-#if 0
-    const auto mask = DeviceGemmInstance::C0MatrixMask(N);
-    s_g_m_n.ForEach([&](auto& self, auto idx) {
-        if(mask.IsMaskedElement(idx[1], idx[2]))
-            self(idx) = -ck::NumericLimits<float>::Infinity();
-    });
-#endif
-
-    // P = Softmax(S)
-    // >>> scipy.special.softmax(numpy.eye(4), 1)
-    // array([[0.47536689, 0.1748777 , 0.1748777 , 0.1748777 ],
-    //        [0.1748777 , 0.47536689, 0.1748777 , 0.1748777 ],
-    //        [0.1748777 , 0.1748777 , 0.47536689, 0.1748777 ],
-    //        [0.1748777 , 0.1748777 , 0.1748777 , 0.47536689]])
-    auto ref_softmax          = ReferenceSoftmaxInstance{};
-    auto ref_softmax_invoker  = ref_softmax.MakeInvoker();
-    auto ref_softmax_argument = ref_softmax.MakeArgument(s_g_m_n, p_g_m_n, 1, 0, {2});
-    ref_softmax_invoker.Run(ref_softmax_argument);
-
-    // Y = P * V
-    auto ref_gemm1          = ReferenceGemm1Instance{};
-    auto ref_gemm1_invoker  = ref_gemm1.MakeInvoker();
-    auto ref_gemm1_argument = ref_gemm1.MakeArgument(
-        p_g_m_n, v_g_n_o, y_g_m_o, PassThrough{}, PassThrough{}, PassThrough{});
-    ref_gemm1_invoker.Run(ref_gemm1_argument);
-
     // Gradients
     auto ref_gemm_grad         = ReferenceGemmGradInstance{};
     auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker();
...
...
@@ -613,7 +605,10 @@ int run(int argc, char* argv[])
                                      kgrad_gs_ns_ks_host_result.mData);
         std::cout << "Checking vgrad:\n";
         pass &= ck::utils::check_err(vgrad_gs_os_ns_device_result.mData,
-                                     vgrad_gs_os_ns_host_result.mData);
+                                     vgrad_gs_os_ns_host_result.mData,
+                                     "error",
+                                     1e-2,
+                                     1e-2);
     }

     return pass ? 0 : 1;
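[Editor's aside, not part of the commit: this is the change the commit title refers to — dV is validated with relaxed rtol/atol of 1e-2, plausibly because dV is accumulated atomically in fp16 so element order and rounding differ from the host reference. The sketch below uses a hypothetical helper, not the actual ck::utils::check_err implementation from library/include/ck/library/utility/check_err.hpp.]

    // Minimal sketch of a mixed relative/absolute tolerance check in the
    // spirit of check_err(out, ref, msg, rtol, atol).
    #include <cmath>
    #include <cstdio>
    #include <vector>

    bool check_err_sketch(const std::vector<float>& out,
                          const std::vector<float>& ref,
                          double rtol = 1e-2,
                          double atol = 1e-2)
    {
        if(out.size() != ref.size())
            return false;
        for(std::size_t i = 0; i < out.size(); ++i)
        {
            const double err = std::abs(double(out[i]) - double(ref[i]));
            // accept if within absolute OR relative tolerance of the reference
            if(err > atol && err > rtol * std::abs(double(ref[i])))
            {
                std::printf("mismatch at %zu: %f vs %f\n", i, out[i], ref[i]);
                return false;
            }
        }
        return true;
    }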
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp  View file @ b67a58c0
...
...
@@ -50,7 +50,8 @@ template <index_t BlockSize,
           index_t NPerXDL,
           index_t MRepeat,
           index_t NRepeat,
-          index_t KPack>
+          index_t KPack,
+          bool TransposeC = false>
 struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
 {
     static constexpr auto I0 = Number<0>{};
...
@@ -72,7 +73,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
     static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);

-    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{};
+    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{};

     static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
...
...
@@ -226,6 +227,21 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
             make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
     }

+    // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
+    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
+    {
+        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
+            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
+                                                           Number<NRepeat>{},
+                                                           Number<MWaves>{},
+                                                           Number<NWaves>{},
+                                                           Number<MPerXDL>{},
+                                                           Number<NPerXDL>{}));
+
+        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
+    }
+
     // XDL output supporting C_xdl = A_xdl * B_xdl
     __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
     {
         constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
...
...
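[Editor's aside, not part of the commit: the new TransposeC = false parameter and the transposed C descriptor rest on the transpose identity]

    C = A B \quad\Longleftrightarrow\quad C^{\top} = B^{\top} A^{\top}

so an MFMA tile whose output is naturally laid out as the transpose can evaluate dV = P^T * dY with the same hardware instruction, without materializing a transposed copy of P.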
include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp  View file @ b67a58c0
...
...
@@ -108,6 +108,24 @@ struct BlockwiseSoftmax
         });
     }

+    template <typename CThreadBuffer, typename LSEBuffer>
+    __host__ __device__ void RunWithPreCalcStats(CThreadBuffer& in_thread_buf,
+                                                 const LSEBuffer& lse_thread_buf)
+    {
+        // calculate exp for elements using pre-calculated stats LSE (log-sum-exp)
+        // Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
+        //    = exp(Si) / exp(log(sum(exp() + ...)))
+        //    = exp(Si - log(sum(exp() + ...)))
+        static_for<0, MRepeat, 1>{}([&](auto iM) {
+            static_for<0, KRepeat, 1>{}([&](auto iK) {
+                auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
+                in_thread_buf(offset) = IgnoreNaN && ck::math::isnan(in_thread_buf[offset])
+                                            ? 0
+                                            : math::exp(in_thread_buf[offset] - lse_thread_buf[iM]);
+            });
+        });
+    }
+
     BufferType max_value_buf;
     BufferType sum_value_buf;
 };
...
...
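[Editor's aside, not part of the commit: per element, RunWithPreCalcStats reduces to the scalar update below. This standalone sketch abstracts away the MRepeat x KRepeat thread-buffer indexing and folds the IgnoreNaN flag into a runtime argument.]

    #include <cmath>

    // One element of the P tile, recomputed from the stored per-row statistic.
    // s: raw score S_i; lse: log-sum-exp of that row; ignore_nan mirrors IgnoreNaN.
    inline float p_from_lse(float s, float lse, bool ignore_nan = true)
    {
        if(ignore_nan && std::isnan(s))
            return 0.f;           // masked-out elements become exact zeros
        return std::exp(s - lse); // P_i = exp(S_i - LSE), no per-row reduction needed
    }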
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp  View file @ b67a58c0
...
...
@@ -18,12 +18,15 @@
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
 #include "ck/library/utility/host_tensor.hpp"

 namespace ck {
 namespace tensor_operation {
 namespace device {

 template <typename GridwiseGemm,
           typename DataType,
+          typename LSEDataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename AccElementwiseOperation,
...
@@ -33,6 +36,7 @@ template <typename GridwiseGemm,
           typename BGridDesc_BK0_N_BK1,
           typename B1GridDesc_BK0_N_BK1,
           typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+          typename LSEGridDescriptor_M,
           typename VGradGridDescriptor_N_O,
           typename YGradGridDesc_M0_O_M1,
           typename Block2CTileMap,
...
@@ -48,6 +52,7 @@ __global__ void
             const DataType* __restrict__ p_b_grid,
             const DataType* __restrict__ p_b1_grid,
             const DataType* __restrict__ p_c_grid,
+            const LSEDataType* __restrict__ p_lse_grid,
             const DataType* __restrict__ p_ygrad_grid,
             DataType* __restrict__ p_qgrad_grid,
             DataType* __restrict__ p_kgrad_grid,
...
@@ -62,6 +67,7 @@ __global__ void
             const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
             const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
+            const LSEGridDescriptor_M lse_grid_desc_m,
             // const QGradGridDescriptor_M_K qgrad_grid_desc_m_k, // TODO ANT: add dQ/dK args
             // const KGradGridDescriptor_N_K kgrad_grid_desc_n_k,
             const VGradGridDescriptor_N_O vgrad_grid_desc_n_o,
...
@@ -87,11 +93,14 @@ __global__ void
         static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
     const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
+    const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
+        static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                   p_b_grid + b_batch_offset,
                                                   p_b1_grid + b1_batch_offset,
                                                   p_c_grid + c_batch_offset,
+                                                  p_lse_grid + lse_batch_offset,
                                                   p_ygrad_grid + c_batch_offset,
                                                   p_qgrad_grid + a_batch_offset,
                                                   p_kgrad_grid + b_batch_offset,
...
@@ -106,6 +115,7 @@ __global__ void
                                                   b_grid_desc_bk0_n_bk1,
                                                   b1_grid_desc_bk0_n_bk1,
                                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                  lse_grid_desc_m,
                                                   vgrad_grid_desc_n_o,
                                                   ygrad_grid_desc_m0_o_m1,
                                                   block_2_ctile_map,
...
@@ -140,6 +150,7 @@ template <index_t NumDimG,
           index_t NumDimK,
           index_t NumDimO, // NumDimGemm1N
           typename DataType,
+          typename LSEDataType,
           typename Acc0BiasDataType,
           typename Acc1BiasDataType,
           typename GemmAccDataType,
...
@@ -350,9 +361,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                                           v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec)
                .second;

+        // LogRangeAsType<float>(std::cout << "v_gs_os_ns_lengths_vec: ", v_gs_os_ns_lengths_vec, ",") << std::endl;
+        // LogRangeAsType<float>(std::cout << "v_gs_os_ns_strides_vec: ", v_gs_os_ns_strides_vec, ",") << std::endl;
+        // LogRangeAsType<float>(std::cout << "v_gs_ns_os_lengths_vec: ", v_gs_ns_os_lengths_vec, ",") << std::endl;
+        // LogRangeAsType<float>(std::cout << "v_gs_ns_os_strides_vec: ", v_gs_ns_os_strides_vec, ",") << std::endl;
         return PadTensorDescriptor(vgrad_desc_nraw_oraw,
                                    make_tuple(NPerBlock, Gemm1NPerBlock),
-                                   Sequence<padder.PadM, padder.PadO>{});
+                                   Sequence<padder.PadN, padder.PadO>{});
     }

     template <typename YGridDesc_M_O, typename Number>
...
...
@@ -415,10 +430,36 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
         return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths_vec, k_gs_ns_ks_strides_vec);
     }

+    static auto MakeLSEGridDescriptor_M(index_t MRaw)
+    {
+        const auto lse_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
+
+        const auto M    = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
+        const auto MPad = M - MRaw;
+
+        if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+                     GemmSpec == GemmSpecialization::MNPadding ||
+                     GemmSpec == GemmSpecialization::MKPadding ||
+                     GemmSpec == GemmSpecialization::MNKPadding)
+        {
+            // pad M
+            return transform_tensor_descriptor(lse_grid_desc_mraw,
+                                               make_tuple(make_right_pad_transform(MRaw, MPad)),
+                                               make_tuple(Sequence<0>{}),
+                                               make_tuple(Sequence<0>{}));
+        }
+        else
+        {
+            // not pad M
+            return lse_grid_desc_mraw;
+        }
+    }
+
     using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
     using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
     using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
     using CGridDesc_M_N        = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
+    using LSEGridDesc_M        = decltype(MakeLSEGridDescriptor_M(1));
     using AGridDesc_G_M_K      = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
     using BGridDesc_G_N_K      = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
     using B1GridDesc_G_N_K     = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
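[Editor's aside, not part of the commit: the padding arithmetic in MakeLSEGridDescriptor_M is just round-up-to-multiple. Standalone sketch with assumed values:]

    #include <cstdio>

    int main()
    {
        const int MRaw      = 1000; // hypothetical problem size
        const int MPerBlock = 128;
        const int M    = (MRaw + MPerBlock - 1) / MPerBlock * MPerBlock; // integer_divide_ceil(...) * MPerBlock
        const int MPad = M - MRaw;
        std::printf("M = %d, MPad = %d\n", M, MPad); // M = 1024, MPad = 24
        return 0;
    }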
...
...
@@ -446,11 +487,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
         ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
                                      const BGridDesc_G_N_K& b_grid_desc_g_n_k,
                                      const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
-                                     const CGridDesc_G_M_N& c_grid_desc_g_m_n)
+                                     const CGridDesc_G_M_N& c_grid_desc_g_m_n,
+                                     index_t BatchStrideLSE)
             : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
               b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
               b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
-              c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
+              c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
+              BatchStrideLSE_(BatchStrideLSE)
         {
         }
...
@@ -474,16 +517,23 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
             return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
         }

+        __host__ __device__ constexpr long_index_t GetLSEBasePtr(index_t g_idx) const
+        {
+            return g_idx * static_cast<long_index_t>(BatchStrideLSE_);
+        }
+
         private:
         AGridDesc_G_M_K a_grid_desc_g_m_k_;
         BGridDesc_G_N_K b_grid_desc_g_n_k_;
         B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
         CGridDesc_G_M_N c_grid_desc_g_m_n_;
+        index_t BatchStrideLSE_;
     };

     // GridwiseGemm
     using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
         DataType, // TODO: distinguish A/B datatype
+        LSEDataType,
         GemmAccDataType,
         CShuffleDataType,
         AElementwiseOperation,
...
@@ -496,6 +546,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
         BGridDesc_BK0_N_BK1,
         B1GridDesc_BK0_N_BK1,
         CGridDesc_M_N,
+        LSEGridDesc_M,
         NumGemmKPrefetchStage,
         BlockSize,
         MPerBlock,
...
@@ -552,6 +603,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                  const DataType* p_b_grid,
                  const DataType* p_b1_grid,
                  const DataType* p_c_grid,
+                 const LSEDataType* p_lse_grid, // for dS
                  const DataType* p_ygrad_grid,
                  DataType* p_qgrad_grid,
                  DataType* p_kgrad_grid,
...
@@ -566,6 +618,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                  const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                  const std::vector<index_t>& c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
                  const std::vector<index_t>& c_gs_ms_gemm1ns_strides,       // c_gs_ms_os_strides
+                 const std::vector<index_t>& lse_gs_ms_lengths,
                  const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
                  const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
                  const std::array<std::vector<ck::index_t>, NumAcc1Bias>
...
@@ -581,6 +634,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
               p_b_grid_{p_b_grid},
               p_b1_grid_{p_b1_grid},
               p_c_grid_{p_c_grid},
+              p_lse_grid_{p_lse_grid},
+              p_ygrad_grid_{p_ygrad_grid},
+              p_qgrad_grid_{p_qgrad_grid},
+              p_kgrad_grid_{p_kgrad_grid},
               p_vgrad_grid_{p_vgrad_grid},
               a_grid_desc_ak0_m_ak1_{
                   DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
...
@@ -590,9 +647,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                   b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
               c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
                                                                   c_gs_ms_gemm1ns_strides)},
+              lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])},
               // dV = P^T * dY
-              vgrad_grid_desc_n_o_{DeviceOp::MakeVGradGridDescriptor_N_O(
-                  c_gs_ms_gemm1ns_lengths, c_gs_ms_gemm1ns_strides)},
+              vgrad_grid_desc_n_o_{DeviceOp::MakeVGradGridDescriptor_N_O(
+                  b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
               /* PTrans descriptor will be constructed in kernel */
               ygrad_grid_desc_m0_o_m1_{
                   DeviceOp::MakeYGradGridDescriptor_M0_O_M1(c_grid_desc_m_n_, Number<Y_M1>{})},
...
@@ -627,7 +685,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                   c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
               batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
-              compute_base_ptr_of_batch_{
-                  a_grid_desc_g_m_k_, b_grid_desc_g_n_k_, b1_grid_desc_g_n_k_, c_grid_desc_g_m_n_}
+              compute_base_ptr_of_batch_{
+                  a_grid_desc_g_m_k_,
+                  b_grid_desc_g_n_k_,
+                  b1_grid_desc_g_n_k_,
+                  c_grid_desc_g_m_n_,
+                  type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
         {
             // TODO ANT: implement bias addition
             ignore = p_acc0_biases;
...
@@ -647,6 +705,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                 c_grid_desc_mblock_mperblock_nblock_nperblock_ =
                     GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                         c_grid_desc_m_n_);
             }
+            Print();
         }

         void Print() const
...
@@ -659,14 +718,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                       << b_grid_desc_g_n_k_.GetLength(I1) << ", "
                       << b_grid_desc_g_n_k_.GetLength(I2) << '\n';
             // b_grid_desc_g_n_k_.Print();
-            std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
+            std::cout << "b1_grid_desc_g_o_n_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
                       << b1_grid_desc_g_n_k_.GetLength(I1) << ", "
                       << b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
             // b1_grid_desc_g_n_k_.Print();
-            std::cout << "c_grid_desc_g_m_n_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", "
+            std::cout << "c_grid_desc_g_m_o_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", "
                       << c_grid_desc_g_m_n_.GetLength(I1) << ", "
                       << c_grid_desc_g_m_n_.GetLength(I2) << '\n';
             // c_grid_desc_g_m_n_.Print();
+            std::cout << "vgrad_grid_desc_n_o_: " << vgrad_grid_desc_n_o_.GetLength(I0) << ", "
+                      << vgrad_grid_desc_n_o_.GetLength(I1) << '\n';
+            std::cout << "ygrad_grid_desc_m0_o_m1_: " << ygrad_grid_desc_m0_o_m1_.GetLength(I0)
+                      << ", " << ygrad_grid_desc_m0_o_m1_.GetLength(I1) << ", "
+                      << ygrad_grid_desc_m0_o_m1_.GetLength(I2) << '\n';
         }

         // pointers
...
...
@@ -674,6 +737,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
         const DataType* p_b_grid_;
         const DataType* p_b1_grid_;
         const DataType* p_c_grid_;
+        const LSEDataType* p_lse_grid_;
         const DataType* p_ygrad_grid_;
         DataType* p_vgrad_grid_;
         DataType* p_qgrad_grid_;
...
@@ -684,6 +748,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
         BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
         B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
         CGridDesc_M_N c_grid_desc_m_n_;
+        LSEGridDesc_M lse_grid_desc_m_;
         AGridDesc_G_M_K a_grid_desc_g_m_k_;
         BGridDesc_G_N_K b_grid_desc_g_n_k_;
         B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
...
@@ -732,7 +797,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
             const index_t grid_size =
                 arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
             std::cout << "grid size = " << grid_size << '\n';
             // Gemm0_K
             const auto K = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) *
                            arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
...
@@ -743,6 +808,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
             const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1<
                 GridwiseGemm,
                 DataType,
+                LSEDataType,
                 AElementwiseOperation,
                 BElementwiseOperation,
                 AccElementwiseOperation,
...
@@ -752,6 +818,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                 DeviceOp::BGridDesc_BK0_N_BK1,
                 DeviceOp::B1GridDesc_BK0_N_BK1,
                 typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                DeviceOp::LSEGridDesc_M,
                 DeviceOp::VGradGridDesc_N_O,
                 DeviceOp::YGradGridDesc_M0_O_M1,
                 typename GridwiseGemm::DefaultBlock2CTileMap,
...
@@ -768,6 +835,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                 arg.p_b_grid_,
                 arg.p_b1_grid_,
                 arg.p_c_grid_,
+                arg.p_lse_grid_,
                 arg.p_ygrad_grid_,
                 arg.p_qgrad_grid_,
                 arg.p_kgrad_grid_,
...
@@ -781,6 +849,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                 arg.b_grid_desc_bk0_n_bk1_,
                 arg.b1_grid_desc_bk0_n_bk1_,
                 arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                arg.lse_grid_desc_m_,
                 arg.vgrad_grid_desc_n_o_,
                 arg.ygrad_grid_desc_m0_o_m1_,
                 arg.block_2_ctile_map_,
...
@@ -791,6 +860,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
             // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
             // to concern Gemm0's loop
+#if 1
             if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
                 ave_time = launch_kernel(integral_constant<bool, true>{});
...
@@ -799,7 +869,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
             {
                 ave_time = launch_kernel(integral_constant<bool, false>{});
             }
+#endif

             return ave_time;
         }
...
...
@@ -898,6 +968,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                              const DataType* p_b,
                              const DataType* p_b1,
                              const DataType* p_c,
+                             const LSEDataType* p_lse,
                              const DataType* p_ygrad_grid,
                              DataType* p_qgrad_grid,
                              DataType* p_kgrad_grid,
...
@@ -912,6 +983,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                              const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                              const std::vector<index_t>& c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
                              const std::vector<index_t>& c_gs_ms_gemm1ns_strides,       // c_gs_ms_os_strides
+                             const std::vector<index_t>& lse_gs_ms_lengths,
                              const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
                              const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
                              const std::array<std::vector<ck::index_t>, NumAcc1Bias>
...
@@ -928,6 +1000,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                         p_b,
                         p_b1,
                         p_c,
+                        p_lse,
                         p_ygrad_grid,
                         p_qgrad_grid,
                         p_kgrad_grid,
...
@@ -942,6 +1015,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                         b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                         c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
                         c_gs_ms_gemm1ns_strides,       // c_gs_ms_os_strides
+                        lse_gs_ms_lengths,
                         acc0_biases_gs_ms_ns_lengths,
                         acc0_biases_gs_ms_ns_strides,
                         acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
...
@@ -962,6 +1036,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                         const void* p_b,
                         const void* p_b1,
                         const void* p_c,
+                        const void* p_lse,
                         const void* p_ygrad_grid,
                         void* p_qgrad_grid,
                         void* p_kgrad_grid,
...
@@ -976,6 +1051,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                         const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                         const std::vector<index_t>& c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
                         const std::vector<index_t>& c_gs_ms_gemm1ns_strides,       // c_gs_ms_os_strides
+                        const std::vector<index_t>& lse_gs_ms_lengths,
                         const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
                         const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
                         const std::array<std::vector<ck::index_t>, NumAcc1Bias>
...
@@ -992,6 +1068,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                         static_cast<const DataType*>(p_b),
                         static_cast<const DataType*>(p_b1),
                         static_cast<const DataType*>(p_c),
+                        static_cast<const LSEDataType*>(p_lse),
                         static_cast<const DataType*>(p_ygrad_grid),
                         static_cast<DataType*>(p_qgrad_grid),
                         static_cast<DataType*>(p_kgrad_grid),
...
@@ -1006,6 +1083,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                         b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                         c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
                         c_gs_ms_gemm1ns_strides,       // c_gs_ms_os_strides
+                        lse_gs_ms_lengths,
                         acc0_biases_gs_ms_ns_lengths,
                         acc0_biases_gs_ms_ns_strides,
                         acc1_biases_gs_ms_gemm1ns_lengths,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp  View file @ b67a58c0
...
...
@@ -21,6 +21,7 @@ namespace ck {
 template <typename DataType,
           typename FloatGemmAcc,
           typename FloatCShuffle,
+          typename FloatLSE,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename AccElementwiseOperation,
...
@@ -31,6 +32,7 @@ template <typename DataType,
           typename BGridDesc_BK0_N_BK1,
           typename B1GridDesc_BK0_N_BK1,
           typename CGridDesc_M_N,
+          typename LSEGridDesc_M,
           index_t NumGemmKPrefetchStage,
           index_t BlockSize,
           index_t MPerBlock,
...
@@ -170,7 +172,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         constexpr index_t n  = Free0_N - 1;
         constexpr index_t n2 = n % NPerXdl;
         constexpr index_t n1 = n / NPerXdl % Gemm0NWaves;
-        constexpr index_t n0 = n / NPerXdl / Gemm0NWaves % MXdlPerWave;
+        constexpr index_t n0 = n / NPerXdl / Gemm0NWaves % NXdlPerWave;

         // assume 256 decomposed into 2 x 4 x 32
         // 1d idx ( 32 - 1) -> 3d idx 0, 0, 31 -> 3d dim 1 x 1 x 32
...
@@ -178,6 +180,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         return Sequence<m0, n0, m1, n1, m2, n2>{} + Sequence<1, 1, 1, 1, 1, 1>{};
     }

+    __host__ __device__ static constexpr auto GetPBlockSliceLengths_M0_N0_M1_N1()
+    {
+        return generate_sequence_v2(
+            [](auto I) { return GetPBlockSliceLengths_M0_N0_M1_N1_M2_N2().At(I); }, Number<4>{});
+    }
+
     // template <typename PBlockDesc_M0_N_M1>
     // __host__ __device__ static constexpr auto
     // MakePMmaTileDescriptor_N0_N1_N2_M(const PBlockDesc_M0_N_M1&)
...
@@ -303,13 +312,21 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         const index_t gemm1_bytes_end =
             (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
             sizeof(DataType);
+        const index_t vgrad_gemm_bytes_end =
+            (SharedMemTrait::p_block_space_size_aligned +
+             SharedMemTrait::ygrad_block_space_size_aligned) *
+            sizeof(DataType);
         const index_t softmax_bytes_end =
             (SharedMemTrait::reduction_space_offset + SharedMemTrait::reduction_space_size_aligned) *
             sizeof(FloatGemmAcc);
         const index_t c_block_bytes_end = SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);

-        return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end);
+        return math::max(gemm0_bytes_end,
+                         gemm1_bytes_end,
+                         vgrad_gemm_bytes_end,
+                         softmax_bytes_end,
+                         c_block_bytes_end);
     }

     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
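[Editor's aside, not part of the commit: the LDS regions above alias the same scratch buffer across pipeline phases, so the block's LDS requirement is the maximum of the phase footprints, not their sum. Standalone sketch with assumed byte counts:]

    #include <algorithm>
    #include <cstdio>

    int main()
    {
        // hypothetical end-of-region byte counts per phase
        const int gemm0_end = 16384, gemm1_end = 8192, vgrad_end = 24576,
                  softmax_end = 4096, c_end = 16384;
        const int lds_bytes = std::max({gemm0_end, gemm1_end, vgrad_end, softmax_end, c_end});
        std::printf("LDS needed: %d bytes\n", lds_bytes); // 24576
        return 0;
    }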
...
...
@@ -395,6 +412,22 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         return c_grid_desc_mblock_mperblock_nblock_nperblock;
     }

+    __host__ __device__ static constexpr auto
+    MakeLSEGridDescriptor_MBlock_MRepeat_NWave_MPerXdl(const LSEGridDesc_M& lse_grid_desc_m)
+    {
+        const index_t M         = lse_grid_desc_m.GetLength(I0);
+        const index_t MBlock    = M / MPerBlock;
+        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
+
+        const auto lse_grid_desc_mblock_mrepeat_mwave_mperxdl = transform_tensor_descriptor(
+            lse_grid_desc_m,
+            make_tuple(make_unmerge_transform(
+                make_tuple(MBlock, Number<MXdlPerWave>{}, MWave, Number<MPerXdl>{}))),
+            make_tuple(Sequence<0>{}),
+            make_tuple(Sequence<0, 1, 2, 3>{}));
+
+        return lse_grid_desc_mblock_mrepeat_mwave_mperxdl;
+    }
+
     // return block_id to C matrix tile idx (m0, n0) mapping
     __host__ __device__ static constexpr auto
     MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
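[Editor's aside, not part of the commit: the unmerge above splits a flat row index m into (mblock, mrepeat, mwave, mperxdl), with the last coordinate varying fastest. Standalone sketch with assumed tile sizes:]

    #include <cstdio>

    int main()
    {
        const int MPerBlock = 128, MXdlPerWave = 2, MPerXdl = 32;
        const int MWave = MPerBlock / (MXdlPerWave * MPerXdl); // 2 waves along M
        const int m = 300;                                     // hypothetical global row index
        const int mperxdl = m % MPerXdl;
        const int mwave   = m / MPerXdl % MWave;
        const int mrepeat = m / (MPerXdl * MWave) % MXdlPerWave;
        const int mblock  = m / MPerBlock;
        std::printf("m=%d -> (mblock=%d, mrepeat=%d, mwave=%d, mperxdl=%d)\n",
                    m, mblock, mrepeat, mwave, mperxdl);
        return 0;
    }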
...
...
@@ -418,8 +451,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
         static constexpr auto b1_block_desc_bk0_n_bk1 =
             GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
+        static constexpr auto p_block_desc_m0_n_m1 =
+            VGradGemmTile_N_O_M::GetPBlockDescriptor_M0_N_M1();
+        static constexpr auto ygrad_block_desc_m0_o_m1 =
+            VGradGemmTile_N_O_M::GetYGradBlockDescriptor_M0_O_M1();

-        static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1);
+        static constexpr auto max_lds_align = Number<16 / sizeof(DataType)>{};

         static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
...
@@ -427,10 +464,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
         static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
             b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
+        static constexpr auto p_block_space_size_aligned = math::integer_least_multiple(
+            p_block_desc_m0_n_m1.GetElementSpaceSize(), max_lds_align);
+        static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
+            ygrad_block_desc_m0_o_m1.GetElementSpaceSize(), max_lds_align);

         static constexpr auto a_block_space_offset  = 0;
         static constexpr auto b_block_space_offset  = a_block_space_size_aligned.value;
         static constexpr auto b1_block_space_offset = 0;
+        static constexpr auto p_block_space_offset     = 0;
+        static constexpr auto ygrad_block_space_offset = p_block_space_size_aligned.value;

         // LDS allocation for reduction
         static constexpr index_t reduction_space_size_aligned =
...
...
@@ -454,6 +497,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 const DataType* __restrict__ p_b_grid,
                 const DataType* __restrict__ p_b1_grid,
                 const DataType* __restrict__ p_c_grid,
+                const FloatLSE* __restrict__ p_lse_grid,
                 const DataType* __restrict__ p_ygrad_grid,
                 DataType* __restrict__ p_qgrad_grid,
                 DataType* __restrict__ p_kgrad_grid,
...
@@ -469,6 +513,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
                 const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                     c_grid_desc_mblock_mperblock_nblock_nperblock,
+                const LSEGridDesc_M& lse_grid_desc_m,
                 const VGradGridDescriptor_N_O& vgrad_grid_desc_n_o,
                 const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1,
                 const Block2CTileMap& block_2_ctile_map,
...
@@ -482,6 +527,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
         const auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
+        auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_lse_grid, lse_grid_desc_m.GetElementSpaceSize());
+        const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize());
+        auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_vgrad_grid, vgrad_grid_desc_n_o.GetElementSpaceSize());

         // divide block work by [M, N]
         const auto block_work_idx =
...
...
@@ -830,6 +881,35 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         running_max     = NumericLimits<FloatGemmAcc>::Lowest();
         running_max_new = NumericLimits<FloatGemmAcc>::Lowest();

+        auto lse_grid_desc_mblock_mrepeat_mwave_mperxdl =
+            MakeLSEGridDescriptor_MBlock_MRepeat_NWave_MPerXdl(lse_grid_desc_m);
+        constexpr auto lse_thread_desc_mblock_mrepeat_mwave_mperxdl =
+            make_naive_tensor_descriptor_packed(make_tuple(I1, m0, m1, m2));
+        auto lse_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatLSE>(
+            lse_thread_desc_mblock_mrepeat_mwave_mperxdl.GetElementSpaceSize());
+
+        auto acc0_thread_origin = blockwise_gemm.CalculateCThreadOriginDataIndex8D(
+            Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{});
+
+        auto lse_thread_copy_global_to_vgpr =
+            ThreadwiseTensorSliceTransfer_v2<FloatLSE,
+                                             FloatLSE,
+                                             decltype(lse_grid_desc_mblock_mrepeat_mwave_mperxdl),
+                                             decltype(lse_thread_desc_mblock_mrepeat_mwave_mperxdl),
+                                             Sequence<1, m0, m1, m2>,
+                                             Sequence<0, 1, 2, 3>,
+                                             3,
+                                             m2,
+                                             1,
+                                             false>{
+                lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
+                make_multi_index(block_work_idx[I0],      // mblock
+                                 acc0_thread_origin[I0],  // mrepeat
+                                 acc0_thread_origin[I2],  // mwave
+                                 acc0_thread_origin[I4])}; // mperxdl
+
         //
         // dV
         //
...
...
@@ -837,11 +917,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         // P vgpr to lds: writes vgprs of a subgroup to LDS and transform into m0_n_m1
         // m0, n0 are m/n repeat per wave
         // m1, n1 are number of waves
-        constexpr auto p_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
+        constexpr auto p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
             blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

-        constexpr auto p_dst_block_desc_m0_n_m1 = VGradGemmTile_N_O_M::GetPBlockDescriptor_M0_N_M1();
+        constexpr auto p_block_desc_m0_n_m1 = VGradGemmTile_N_O_M::GetPBlockDescriptor_M0_N_M1();

         constexpr auto p_block_lengths =
             blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
...
@@ -854,30 +933,28 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         constexpr auto P_N3 = p_block_lengths[I6];
         constexpr auto P_N4 = p_block_lengths[I7];

-        constexpr auto p_dst_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = [&]() constexpr
+        constexpr auto p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = [&]() constexpr
         {
-            constexpr auto p_dst_block_desc_m_n = transform_tensor_descriptor(
-                p_dst_block_desc_m0_n_m1,
+            constexpr auto p_block_desc_m_n = transform_tensor_descriptor(
+                p_block_desc_m0_n_m1,
                 make_tuple(make_merge_transform_v3_division_mod(
                                make_tuple(VGradGemmTile_N_O_M::P_M0, VGradGemmTile_N_O_M::P_M1)),
                            make_pass_through_transform(VGradGemmTile_N_O_M::Free0_N)),
                 make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
                 make_tuple(Sequence<0>{}, Sequence<1>{}));

+            // HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
+            // variable I1 there
             return transform_tensor_descriptor(
-                p_dst_block_desc_m_n,
-                make_tuple(make_unmerge_transform(make_tuple(P_M0, P_M1, P_M2)),
-                           make_unmerge_transform(make_tuple(P_N0, P_N1, P_N2, P_N3, P_N4))),
+                p_block_desc_m_n,
+                make_tuple(make_unmerge_transform(make_tuple(I1, P_M1, P_M2)),
+                           make_unmerge_transform(make_tuple(I1, P_N1, P_N2, P_N3, P_N4))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
         }
         ();

-        // TODO ANT: check lds offset
-        auto p_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<DataType*>(p_shared), p_dst_block_desc_m0_n_m1.GetElementSpaceSize());
-
-        const auto p_dst_thread_origin = [&]() {
+        const auto p_thread_origin_nd_idx_on_block = [&]() {
             const auto c_thread_mtx_on_block =
                 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
...
...
@@ -904,8 +981,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
                     make_multi_index(n_thread_data_on_block));

-            return make_tuple(0, // mrepeat
-                              0, // nrepeat
+            return make_tuple(m_thread_data_on_block_idx[I0], // mrepeat
+                              n_thread_data_on_block_idx[I0], // nrepeat
                               m_thread_data_on_block_idx[I1], // mwave
                               n_thread_data_on_block_idx[I1], // nwave
                               m_thread_data_on_block_idx[I2], // xdlops
...
@@ -914,19 +991,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                               n_thread_data_on_block_idx[I4]);
         }();

-        constexpr auto p_block_slice_lengths_m0_n0_m1_n1_m2_n2 = // mrepeat, nrepeat, mwaves,
-                                                                 // nwaves, mperxdl, nperxdl
-            VGradGemmTile_N_O_M::GetPBlockSliceLengths_M0_N0_M1_N1_M2_N2();
+        constexpr auto p_block_slice_lengths_m0_n0_m1_n1 = // mrepeat, nrepeat,
+                                                           // mwaves, nwaves,
+            VGradGemmTile_N_O_M::GetPBlockSliceLengths_M0_N0_M1_N1();

         // how to properly perform copy for a sub-workgroup?
         auto p_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
             FloatGemmAcc,
             DataType,
-            decltype(p_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
-            decltype(p_dst_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
+            decltype(p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
+            decltype(p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
             tensor_operation::element_wise::PassThrough,
-            Sequence<p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I0], // ThreadSliceLengths
-                     p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I1],
+            Sequence<p_block_slice_lengths_m0_n0_m1_n1[I0], // ThreadSliceLengths
+                     p_block_slice_lengths_m0_n0_m1_n1[I1],
                      I1,
                      I1,
                      I1,
...
...
@@ -939,21 +1016,35 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             InMemoryDataOperationEnum::Set,
             1, // DstScalarStrideInVector
             true>{
-            p_dst_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-            make_multi_index(p_dst_thread_origin[I0],
-                             p_dst_thread_origin[I1],
-                             p_dst_thread_origin[I2] % p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I2],
-                             p_dst_thread_origin[I3] % p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I3],
-                             p_dst_thread_origin[I4],
-                             p_dst_thread_origin[I5],
-                             p_dst_thread_origin[I6],
-                             p_dst_thread_origin[I7]),
+            p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+            make_multi_index(
+                p_thread_origin_nd_idx_on_block[I0],
+                p_thread_origin_nd_idx_on_block[I1],
+                p_thread_origin_nd_idx_on_block[I2] % p_block_slice_lengths_m0_n0_m1_n1[I2],
+                p_thread_origin_nd_idx_on_block[I3] % p_block_slice_lengths_m0_n0_m1_n1[I3],
+                p_thread_origin_nd_idx_on_block[I4],
+                p_thread_origin_nd_idx_on_block[I5],
+                p_thread_origin_nd_idx_on_block[I6],
+                p_thread_origin_nd_idx_on_block[I7]),
             tensor_operation::element_wise::PassThrough{}};

+        // Sequence<p_block_slice_lengths_m0_n0_m1_n1[I0],
+        //          p_block_slice_lengths_m0_n0_m1_n1[I1],
+        //          I1, I1, I1, P_N2, I1, P_N4>{}
+        //     .foo();
+        // 1, 4, 1, 1, 1, 4, 1, 4
         // construct space filling curve
-        // p_thread_copy_vgpr_to_lds.Run();
+        constexpr auto sfc_p_m0_n0_m1_n1_m2_n2 =
+            SpaceFillingCurve<Sequence<P_M0, P_N0, P_M1, P_N1>,
+                              Sequence<0, 1, 2, 3>,
+                              decltype(p_block_slice_lengths_m0_n0_m1_n1)>{};

-        constexpr auto ygrad_dst_block_desc_m0_o_m1 =
+        constexpr auto ygrad_block_desc_m0_o_m1 =
             VGradGemmTile_N_O_M::GetYGradBlockDescriptor_M0_O_M1();

         auto ygrad_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
...
@@ -967,7 +1058,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             DataType,
             DataType,
             decltype(ygrad_grid_desc_m0_o_m1),
-            decltype(ygrad_dst_block_desc_m0_o_m1),
+            decltype(ygrad_block_desc_m0_o_m1),
             typename VGradGemmTile_N_O_M::YGrad_ThreadClusterArrangeOrder, // access order == thread
                                                                            // order
             Sequence<1, 0, 2>,
...
@@ -980,113 +1071,165 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             true,
             true,
             1>(ygrad_grid_desc_m0_o_m1,
-               make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0),
+               make_multi_index(m_block_data_idx_on_grid / VGradGemmTile_N_O_M::YGrad_M1,
+                                gemm1_n_block_data_idx_on_grid,
+                                0),
                tensor_operation::element_wise::PassThrough{},
-               ygrad_dst_block_desc_m0_o_m1,
+               ygrad_block_desc_m0_o_m1,
                make_multi_index(0, 0, 0),
                tensor_operation::element_wise::PassThrough{});

+        auto p_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+            static_cast<DataType*>(p_shared) + SharedMemTrait::p_block_space_offset,
+            p_block_desc_m0_n_m1.GetElementSpaceSize());
+        auto ygrad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+            static_cast<DataType*>(p_shared) + SharedMemTrait::ygrad_block_space_offset,
+            ygrad_block_desc_m0_o_m1.GetElementSpaceSize());

         auto vgrad_blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
             BlockSize,
             DataType,
             FloatGemmAcc,
-            decltype(p_dst_block_desc_m0_n_m1),
-            decltype(ygrad_dst_block_desc_m0_o_m1),
+            decltype(p_block_desc_m0_n_m1),
+            decltype(ygrad_block_desc_m0_o_m1),
             MPerXdl,
             NPerXdl,
-            VGradGemmTile_N_O_M::GemmNRepeat, // NRepeat
-            VGradGemmTile_N_O_M::GemmORepeat, // ORepeat
-            VGradGemmTile_N_O_M::GemmMPack>{};
+            VGradGemmTile_N_O_M::GemmNRepeat,
+            VGradGemmTile_N_O_M::GemmORepeat,
+            VGradGemmTile_N_O_M::GemmMPack,
+            true>{}; // TransposeC

         auto vgrad_acc_thread_buf = vgrad_blockwise_gemm.GetCThreadBuffer();

         // HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
         // variable I1 there
         const auto vgrad_grid_desc_n0_o0_n1_o1_n2_o2 = transform_tensor_descriptor(
             vgrad_grid_desc_n_o,
             make_tuple(make_unmerge_transform(
                            make_tuple(I1, // may place a dummy variable
-                                      p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I2],
-                                      p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I4])),
-                       make_unmerge_transform(
-                           make_tuple(I1,
-                                      p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I3],
-                                      p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I5]))),
+                                      VGradGemmTile_N_O_M::GemmNWave,
+                                      MPerXdl)),
+                       make_unmerge_transform(
+                           make_tuple(I1, VGradGemmTile_N_O_M::GemmOWave, NPerXdl))),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));

-        constexpr auto vgrad_thread_desc_n0_o0_n1_o1_n2_n3_n4_o2 =
-            vgrad_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
+        constexpr auto vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
+            vgrad_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

-        const auto vgrad_grid_desc_n0_o0_n1_o1_n2_n3_n4_o2 =
-            vgrad_blockwise_gemm.xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
+        const auto vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
+            vgrad_blockwise_gemm.xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(
                 vgrad_grid_desc_n0_o0_n1_o1_n2_o2);

         const auto vgrad_thread_mtx_on_block_n_o =
             vgrad_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

-        constexpr auto vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2 =
-            decltype(vgrad_blockwise_gemm)::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
-        constexpr auto VGrad_N0 = vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2.GetLength(I0);
-        constexpr auto VGrad_O0 = vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2.GetLength(I1);
-        constexpr auto VGrad_N1 = vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2.GetLength(I2);
-        constexpr auto VGrad_O1 = vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2.GetLength(I3);
-        constexpr auto VGrad_N2 = vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2.GetLength(I4);
-        constexpr auto VGrad_N3 = vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2.GetLength(I5);
-        constexpr auto VGrad_N4 = vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2.GetLength(I6);
-        constexpr auto VGrad_O2 = vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2.GetLength(I7);
+        constexpr auto vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
+            decltype(vgrad_blockwise_gemm)::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+        constexpr auto VGrad_N0 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I0);
+        constexpr auto VGrad_O0 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I1);
+        constexpr auto VGrad_N1 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I2);
+        constexpr auto VGrad_O1 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I3);
+        constexpr auto VGrad_N2 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I4);
+        constexpr auto VGrad_O2 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I5);
+        constexpr auto VGrad_O3 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I6);
+        constexpr auto VGrad_O4 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I7);

+        // TODO ANT: step n after each Gemm1 outer loop
         const index_t n_thread_data_idx_on_grid = vgrad_thread_mtx_on_block_n_o[I0];

         const index_t o_thread_data_idx_on_grid =
             vgrad_thread_mtx_on_block_n_o[I1] + gemm1_n_block_data_idx_on_grid;

-        const auto n_thread_data_on_grid_to_n0_n1_n2_n3_n4_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(make_merge_transform(
-                make_tuple(VGrad_N0, VGrad_N1, VGrad_N2, VGrad_N3, VGrad_N4))),
-            make_tuple(Sequence<0, 1, 2, 3, 4>{}),
-            make_tuple(Sequence<0>{}));
+        const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(make_merge_transform(make_tuple(VGrad_N0, VGrad_N1, VGrad_N2))),
+            make_tuple(Sequence<0, 1, 2>{}),
+            make_tuple(Sequence<0>{}));

         const auto n_thread_data_nd_idx_on_grid =
-            n_thread_data_on_grid_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
+            n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                 make_multi_index(n_thread_data_idx_on_grid));

-        const auto o_thread_data_on_grid_to_o0_o1_o2_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(make_merge_transform(make_tuple(VGrad_O0, VGrad_O1, VGrad_O2))),
-            make_tuple(Sequence<0, 1, 2>{}),
-            make_tuple(Sequence<0>{}));
+        const auto o_thread_data_on_grid_to_o0_o1_o2_o3_o4_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(make_merge_transform(
+                make_tuple(VGrad_O0, VGrad_O1, VGrad_O2, VGrad_O3, VGrad_O4))),
+            make_tuple(Sequence<0, 1, 2, 3, 4>{}),
+            make_tuple(Sequence<0>{}));

         const auto o_thread_data_nd_idx_on_grid =
-            o_thread_data_on_grid_to_o0_o1_o2_adaptor.CalculateBottomIndex(
+            o_thread_data_on_grid_to_o0_o1_o2_o3_o4_adaptor.CalculateBottomIndex(
                 make_multi_index(o_thread_data_idx_on_grid));

         auto vgrad_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
             FloatGemmAcc,
             DataType,
-            decltype(vgrad_thread_desc_n0_o0_n1_o1_n2_n3_n4_o2),
-            decltype(vgrad_grid_desc_n0_o0_n1_o1_n2_n3_n4_o2),
+            decltype(vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4),
+            decltype(vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4),
             tensor_operation::element_wise::PassThrough, // CElementwiseOperation
-            decltype(vgrad_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
-                         .GetLengths()), // SliceLengths
+            decltype(vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLengths()), // SliceLengths
             Sequence<0, 1, 2, 3, 4, 5, 6, 7>,     // AccessOrder
             7,                                    // VectorDim
-            1,                                    // ScalarPerVector
+            2,                                    // ScalarPerVector
             InMemoryDataOperationEnum::AtomicAdd, // GlobalMemoryDataOperation
             1,                                    // DstScalarStrideInVector
             true>(
-            vgrad_grid_desc_n0_o0_n1_o1_n2_n3_n4_o2,
+            vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
             make_multi_index(n_thread_data_nd_idx_on_grid[I0],
                              o_thread_data_nd_idx_on_grid[I0],
                              n_thread_data_nd_idx_on_grid[I1],
                              o_thread_data_nd_idx_on_grid[I1],
                              n_thread_data_nd_idx_on_grid[I2],
-                             n_thread_data_nd_idx_on_grid[I3],
-                             n_thread_data_nd_idx_on_grid[I4],
-                             o_thread_data_nd_idx_on_grid[I2]),
+                             o_thread_data_nd_idx_on_grid[I2],
+                             o_thread_data_nd_idx_on_grid[I3],
+                             o_thread_data_nd_idx_on_grid[I4]),
             tensor_operation::element_wise::PassThrough{});

+        // TODO ANT: ygrad slice window step size
+#if 0
         if(hipThreadIdx_x % 32 < 4)
         {
             printf("wid %zd tid %zd _n0_o0_n1_o1_n2_o2_o3_o4 %d %d %d %d %d %d %d %d\n",
                    hipBlockIdx_x,
                    hipThreadIdx_x,
                    n_thread_data_nd_idx_on_grid[I0],
                    o_thread_data_nd_idx_on_grid[I0],
                    n_thread_data_nd_idx_on_grid[I1],
                    o_thread_data_nd_idx_on_grid[I1],
                    n_thread_data_nd_idx_on_grid[I2],
                    o_thread_data_nd_idx_on_grid[I2],
                    o_thread_data_nd_idx_on_grid[I3],
                    o_thread_data_nd_idx_on_grid[I4]);
         }
#endif

         // p_thread_slice_copy_step will be in for loop
         constexpr auto ygrad_block_slice_copy_step =
             make_multi_index(VGradGemmTile_N_O_M::YGrad_M0, 0, 0);
         constexpr auto ygrad_block_reset_copy_step =
             make_multi_index(-MPerBlock / VGradGemmTile_N_O_M::YGrad_M1, 0, 0);
         // vgrad gemm output tile
         const auto vgrad_block_slice_copy_step =
             make_multi_index(VGradGemmTile_N_O_M::GemmNRepeat, 0, 0, 0, 0, 0, 0, 0);

#if 0
         if(hipThreadIdx_x == 0)
         {
             printf("bid %zd, n_grid = %d, o_grid = %d, step N0 = %d\n",
                    hipBlockIdx_x,
                    n_thread_data_idx_on_grid,
                    o_thread_data_idx_on_grid,
                    n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                        make_multi_index(NPerBlock))[I0]);
         }
#endif

         constexpr index_t num_vgrad_gemm_loop = MPerBlock / VGradGemmTile_N_O_M::Sum_M;

+        lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
+                                           lse_grid_buf,
+                                           lse_thread_desc_mblock_mrepeat_mwave_mperxdl,
+                                           make_tuple(I0, I0, I0, I0),
+                                           lse_thread_buf);

         // gemm1 K loop
         index_t gemm1_k_block_outer_index = 0;
         do
...
...
@@ -1178,131 +1321,123 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
block_sync_lds
();
// wait for lds read in gemm0 blockwise gemm
// softmax
SoftmaxBuf& max = blockwise_softmax.max_value_buf;
SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
blockwise_softmax.Run(acc_thread_buf, workspace_buf);
// TODO: may convert to log domain
running_max_new = mathext::max(max, running_max);
running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
mathext::exp(max - running_max_new) * sum;
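// Illustrative sketch, not part of this commit: the removed lines above are the
// online-softmax running-statistics update (as in FlashAttention). Scalar form
// for a single row; the kernel applies it per row across tiles.
#include <algorithm>
#include <cmath>
inline void update_running_stats(float& running_max, float& running_sum,
                                 float tile_max, float tile_sum)
{
    const float new_max = std::max(running_max, tile_max);
    running_sum = std::exp(running_max - new_max) * running_sum +
                  std::exp(tile_max - new_max) * tile_sum;
    running_max = new_max;
}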
// gemm1
#if 0
if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
// RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
// the A1 source buffer is static buffer holding the output of first GEMM and
// requires constexpr offset by design. Therefore, we pass tensor coordinate offset
// explicitly in Run() below.
// Initialize acc1
acc1_thread_buf.Clear();
// preload data into LDS
b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
b1_block_slice_copy_step);
block_sync_lds(); // wait for reduction LDS read
printf("tid %zd, S[0:3] = %f, %f, %f, %f\n",
hipThreadIdx_x,
acc_thread_buf[I0],
acc_thread_buf[I1],
acc_thread_buf[I2],
acc_thread_buf[I3]);
}
#endif
b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
// softmax
blockwise_softmax.RunWithPreCalcStats(acc_thread_buf, lse_thread_buf);
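// Illustrative sketch, not part of this commit: with the forward pass saving a
// per-row log-sum-exp statistic (the LSE tensor loaded above), the backward
// pass can recover each softmax probability directly, without re-reducing the
// row. RunWithPreCalcStats is assumed to compute, per element:
#include <cmath>
inline float p_from_stats(float s, float lse)
{
    return std::exp(s - lse); // softmax(s_i) given lse = max + log(sum exp(s_j - max))
}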
// main body
if constexpr(num_gemm1_k_block_inner_loop > 1)
#if 0
if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{
static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
make_tuple(Number<i * A1ThreadSliceK0>{}, I0, I0),
acc_thread_buf,
a1_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
a1_thread_buf);
printf("tid %zd, P[0:3] = %f, %f, %f, %f\n",
hipThreadIdx_x,
acc_thread_buf[I0],
acc_thread_buf[I1],
acc_thread_buf[I2],
acc_thread_buf[I3]);
}
#endif
b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
block_sync_lds(); // wait for gemm1 LDS read
block_sync_lds();
SubThreadBlock<BlockSize> p_thread_copy_subgroup(blockwise_gemm.GetWaveIdx()[I0],
                                                 blockwise_gemm.GetWaveIdx()[I1]);
gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
static_assert(sfc_p_m0_n0_m1_n1_m2_n2.GetNumOfAccess() == num_vgrad_gemm_loop, "");
block_sync_lds();
vgrad_acc_thread_buf.Clear();
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
b1_block_slice_copy_step);
// TODO ANT: single buffer prefetch pipeline
static_for<0, num_vgrad_gemm_loop, 1>{}([&](auto vgrad_gemm_loop_idx) {
// gemm dV
// load VGrad Gemm B
    ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1, ygrad_grid_buf);
b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
});
}
// tail
// load VGrad Gemm A
    const auto p_nd_idx = sfc_p_m0_n0_m1_n1_m2_n2.GetIndexTupleOfNumber(vgrad_gemm_loop_idx);
    constexpr auto mwave_range = make_tuple(
        p_nd_idx[I2], p_nd_idx[I2] + p_block_slice_lengths_m0_n0_m1_n1[I2]);
    constexpr auto nwave_range = make_tuple(
        p_nd_idx[I3], p_nd_idx[I3] + p_block_slice_lengths_m0_n0_m1_n1[I3]);
    if(p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range))
    {
        a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
        p_thread_copy_vgpr_to_lds.Run(p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
            make_tuple(Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0),
            make_tuple(p_nd_idx[I0], p_nd_idx[I1], I0, I0, I0, I0, I0, I0),
            acc_thread_buf,
            a1_thread_desc_k0_m_k1,
            make_tuple(I0, I0, I0),
            a1_thread_buf);
            p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
            p_block_buf);
    }
block_sync_lds();
// ygrad slice window is moved with MoveSrcSliceWindow() since it is dynamic buffer
// p slice window is moved by loop index
    ygrad_blockwise_copy.MoveSrcSliceWindow(ygrad_grid_desc_m0_o_m1,
                                            ygrad_block_slice_copy_step);
gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
    block_sync_lds(); // sync before write
    ygrad_blockwise_copy.RunWrite(ygrad_block_desc_m0_o_m1, ygrad_block_buf);
#if 0
if (hipBlockIdx_x == 0)
{
debug::print_shared(
p_block_buf.p_data_,
index_t(p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize()));
}
#endif
#if 0
if (hipBlockIdx_x == 0)
{
debug::print_shared(ygrad_block_buf.p_data_,
index_t(ygrad_block_desc_m0_o_m1.GetElementSpaceSize()));
}
} // end gemm1
#endif
    block_sync_lds(); // sync before read
    vgrad_blockwise_gemm.Run(p_block_buf, ygrad_block_buf, vgrad_acc_thread_buf);
// workaround compiler issue; see ck/ck.hpp
if constexpr(CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE == 1 &&
is_same_v<FloatAB, bhalf_t> && MPerBlock == 256 && NPerBlock == 128 &&
Gemm1NPerBlock == 128)
#if 1
    if(hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{
__builtin_amdgcn_sched_barrier(0);
        printf("outer %d inner %d tid %zd, dV[0:3] = %f, %f, %f, %f\n",
               gemm1_k_block_outer_index,
               vgrad_gemm_loop_idx.value,
               hipThreadIdx_x,
               vgrad_acc_thread_buf[I0],
               vgrad_acc_thread_buf[I1],
               vgrad_acc_thread_buf[I2],
               vgrad_acc_thread_buf[I3]);
}
#endif
});
// end gemm dV
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto cn0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto cm1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto cn1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto cm2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto cn2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto cn3 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto cn4 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(cm0 * cm1 * cm2, cn0 * cn1 * cn2 * cn3 * cn4));
constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V
FloatGemmAcc c = c_thread_buf[I]; // O
FloatGemmAcc c_new =
(running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
math::exp(max[iM] - running_max_new[iM]) * acc1) /
running_sum_new[iM]; // Formula by Dao et al.,
// https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
c_thread_buf(I) = c_new; // O_new
});
});
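// Illustrative sketch, not part of this commit: scalar form of the removed
// per-element output update above (Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf
// section 3.1). `l` is the running sum, `m` the running max, `acc1` this tile's P*V.
#include <cmath>
inline float rescale_output(float c_old, float acc1,
                            float l_old, float m_old,
                            float m_tile, float m_new, float l_new)
{
    return (l_old * std::exp(m_old - m_new) * c_old +
            std::exp(m_tile - m_new) * acc1) / l_new;
}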
// atomic_add vgrad
vgrad_thread_copy_vgpr_to_global.Run(vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
                                     make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                     vgrad_acc_thread_buf,
                                     vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
                                     vgrad_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1,
                                    a_block_reset_copy_step); // rewind K
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_bk0_n_bk1,
                                    b_block_reset_copy_step); // rewind K and step N
ygrad_blockwise_copy.MoveSrcSliceWindow(ygrad_grid_desc_m0_o_m1,
                                        ygrad_block_reset_copy_step); // rewind M
vgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
    vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
    vgrad_block_slice_copy_step); // step N
// update before next j iteration
running_max = running_max_new;
running_sum = running_sum_new;
block_sync_lds(); // wait for gemm1 LDS read
}
while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
#endif
// TODO ANT:
// shuffle dQ and write
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
View file @
b67a58c0
...
...
@@ -137,6 +137,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
    constexpr index_t src_offset = src_desc.CalculateOffset(
        src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
// Sequence<num_access, idx_1d.value, i.value, src_offset>{}.foo();
    SrcData v;
...
...
include/ck/utility/thread_group.hpp
View file @
b67a58c0
...
...
@@ -28,8 +28,8 @@ struct SubThreadBlock
__device__ static constexpr index_t GetNumOfThread() { return kNumThread_; }
template <typename Tuple2>
__device__ constexpr bool IsBelong(const Tuple2& mwave_range, const Tuple2& nwave_range)
template <typename TupleArg1, typename TupleArg2>
__device__ constexpr bool IsBelong(const TupleArg1& mwave_range,
                                   const TupleArg2& nwave_range)
{
// wave_range[I0] inclusive, wave_range[I1] exclusive
    if(mwave_ < mwave_range[I0])
...
...
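// Illustrative sketch, not part of this commit: the membership test IsBelong
// performs, with [first, second) semantics per the inclusive/exclusive comment
// in the hunk above. `Range` and the wave-id arguments are hypothetical
// stand-ins for the tuple parameters and the mwave_/nwave_ members.
struct Range { int first; int second; }; // [first, second)
inline bool is_belong(int m_wave, int n_wave, Range m_range, Range n_range)
{
    return m_wave >= m_range.first && m_wave < m_range.second &&
           n_wave >= n_range.first && n_wave < n_range.second;
}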
library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp
View file @
b67a58c0
...
...
@@ -149,6 +149,13 @@ struct ReferenceSoftmax : public device::BaseOperator
                  ck::type_convert<AccDataType>(arg.sm_stats_ptr_[0](to_sm_stats_idx(idx)))) +
              arg.beta_ * self(idx);
// printf(
// "exponent %f, exp() = %f\n",
// ck::type_convert<AccDataType>(arg.in_(idx)) -
// ck::type_convert<AccDataType>(arg.sm_stats_ptr_[0](to_sm_stats_idx(idx))),
// std::exp(
// ck::type_convert<AccDataType>(arg.in_(idx)) -
// ck::type_convert<AccDataType>(arg.sm_stats_ptr_[0](to_sm_stats_idx(idx)))));
});
return 0;
...
...
library/include/ck/library/utility/check_err.hpp
View file @
b67a58c0
...
...
@@ -148,7 +148,7 @@ check_err(const Range& out,
{
    max_err = err > max_err ? err : max_err;
    err_count++;
    if(err_count < 5)
    if(err_count < 32)
{
        std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                  << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
...
...
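// Illustrative sketch, not part of this commit: the overall shape of a
// check_err-style comparison with a relaxed tolerance, as used to validate dV
// per the commit message. `atol`/`rtol` are hypothetical; the library's real
// thresholds are parameters of check_err elsewhere in this header.
#include <cmath>
#include <cstdio>
inline bool check_close(const float* out, const float* ref, int n,
                        float atol, float rtol)
{
    int err_count = 0;
    for(int i = 0; i < n; ++i)
    {
        const float err = std::fabs(out[i] - ref[i]);
        if(err > atol + rtol * std::fabs(ref[i]))
        {
            ++err_count;
            if(err_count < 32) // report up to 32 mismatches, as in the hunk above
                std::printf("out[%d] != ref[%d]: %f != %f\n", i, i, out[i], ref[i]);
        }
    }
    return err_count == 0;
}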
test/space_filling_curve/CMakeLists.txt
View file @
b67a58c0
add_test_executable(test_space_filling_curve space_filling_curve.cpp)
add_test_executable(test_threadwise_copy test_threadwise_copy.cpp)