gaoqiong / composable_kernel · Commit b67a58c0
Authored Nov 29, 2022 by Anthony Chang

can validate dV with relaxed error tolerance

Parent: 8551dd43

Showing 10 changed files with 513 additions and 264 deletions
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp                +52  -57
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp                                         +18  -2
include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp                                             +18  -0
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp   +88  -10
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp             +326 -191
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp                             +1   -0
include/ck/utility/thread_group.hpp                                                                     +2   -2
library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp                         +7   -0
library/include/ck/library/utility/check_err.hpp                                                        +1   -1
test/space_filling_curve/CMakeLists.txt                                                                 +0   -1
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp

@@ -15,7 +15,9 @@ Outputs:
 */
-#define PRINT_HOST 1
+#pragma clang diagnostic ignored "-Wunused-variable"
+#define PRINT_HOST 0
 #include <iostream>
 #include <numeric>

@@ -50,6 +52,7 @@ using YElementOp = PassThrough;
 using DataType         = F16;
 using AccDataType      = F32;
 using ShuffleDataType  = F32;
+using LSEDataType      = F32;
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;

@@ -76,6 +79,7 @@ using DeviceGemmInstance =
        NumDimK,
        NumDimO,
        DataType,
+        LSEDataType,
        Acc0BiasDataType,
        Acc1BiasDataType,
        AccDataType,

@@ -172,14 +176,16 @@ template <typename TensorQ,
          typename TensorV,
          typename TensorS,
          typename TensorP,
-          typename TensorY>
+          typename TensorY,
+          typename TensorLSE = TensorP>
 void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             const TensorK& k_g_n_k,
                             const TensorV& v_g_n_o,
                             const float alpha,
                             TensorS& s_g_m_n,
                             TensorP& p_g_m_n,
-                            TensorY& y_g_m_o)
+                            TensorY& y_g_m_o,
+                            TensorLSE& lse_g_m)
 {
     // S = alpha * Q * K^T
     auto k_g_k_n = k_g_n_k.Transpose({0, 2, 1});

@@ -207,7 +213,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     //         [0.1748777 , 0.1748777 , 0.1748777 , 0.47536689]])
     auto ref_softmax          = ReferenceSoftmaxInstance{};
     auto ref_softmax_invoker  = ref_softmax.MakeInvoker();
-    auto ref_softmax_argument = ref_softmax.MakeArgument(s_g_m_n, p_g_m_n, 1, 0, {2});
+    auto ref_softmax_argument = ref_softmax.MakeArgument(s_g_m_n, p_g_m_n, 1, 0, {2}, &lse_g_m);
     ref_softmax_invoker.Run(ref_softmax_argument);

@@ -230,10 +236,10 @@ int run(int argc, char* argv[])
     // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
     // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
     // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
-    ck::index_t M  = 4;
-    ck::index_t N  = 4;
-    ck::index_t K  = 4;
-    ck::index_t O  = 4;
+    ck::index_t M  = 256;
+    ck::index_t N  = 256;
+    ck::index_t K  = 256;
+    ck::index_t O  = 256;
     ck::index_t G0 = 1;
     ck::index_t G1 = 1;

@@ -242,8 +248,6 @@ int run(int argc, char* argv[])
     bool input_permute  = false;
     bool output_permute = false;
-    const ck::index_t BatchCount = G0 * G1;
-
     if(argc == 1)
     {
         // use default case

@@ -283,6 +287,8 @@ int run(int argc, char* argv[])
         exit(0);
     }

+    const ck::index_t BatchCount = G0 * G1;
+
     std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
     std::vector<ck::index_t> q_gs_ms_ks_strides =
         input_permute

@@ -307,16 +313,27 @@ int run(int argc, char* argv[])
             ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
             : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]

+    // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
+    // Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
+    //    = exp(Si) / exp(log(sum(exp() + ...)))
+    //    = exp(Si - log(sum(exp() + ...)))
+    //                ^^^^^^^^^^^^^^^^^^^^^
+    //                         LSE
+    std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
+    std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
+
     Tensor<DataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
     Tensor<DataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
     Tensor<DataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
     Tensor<DataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
     Tensor<DataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
+    Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);

     std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
     std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
     std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
     std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
+    std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;

     switch(init_method)
     {

@@ -340,19 +357,20 @@ int run(int argc, char* argv[])
         ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
         break;
     default:
-        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
+        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{10});
+        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, n] = m
     }

-    // calculate y beforehand
+    // calculate y & log-sum-exp beforehand
     Tensor<DataType> q_g_m_k({BatchCount, M, K});
     Tensor<DataType> k_g_n_k({BatchCount, N, K});
     Tensor<DataType> v_g_n_o({BatchCount, N, O});
     Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
     Tensor<DataType> p_g_m_n({BatchCount, M, N});
     Tensor<DataType> y_g_m_o({BatchCount, M, O});
+    Tensor<LSEDataType> lse_g_m({BatchCount, M});
     q_gs_ms_ks.ForEach(
         [&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });

@@ -360,27 +378,36 @@ int run(int argc, char* argv[])
         [&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
     v_gs_os_ns.ForEach(
         [&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
+    lse_gs_ms.ForEach(
+        [&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });

-    run_attention_fwd_host(q_g_m_k, k_g_n_k, v_g_n_o, alpha, s_g_m_n, p_g_m_n, y_g_m_o);
+    run_attention_fwd_host(q_g_m_k, k_g_n_k, v_g_n_o, alpha, s_g_m_n, p_g_m_n, y_g_m_o, lse_g_m);

     y_gs_ms_os.ForEach(
         [&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); });
+    lse_gs_ms.ForEach(
+        [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });

     // qkv gradients have the same descriptor as with qkv
     DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
     DeviceMem k_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
     DeviceMem v_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
     DeviceMem y_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
+    DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize());
     DeviceMem qgrad_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
     DeviceMem kgrad_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
     DeviceMem vgrad_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
     DeviceMem ygrad_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());

-    // TODO ANT: make sure K/V gradients are zeroed
     q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
     k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
     v_device_buf.ToDevice(v_gs_os_ns.mData.data());
     y_device_buf.ToDevice(y_gs_ms_os.mData.data());
-    ygrad_device_buf.ToDevice(y_gs_ms_os.mData.data());
+    lse_device_buf.ToDevice(lse_gs_ms.mData.data());
+    ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
+    kgrad_device_buf.SetZero();
+    vgrad_device_buf.SetZero();

-    // TODO ANT: attention backward kernel
     auto gemm     = DeviceGemmInstance{};
     auto invoker  = gemm.MakeInvoker();
     auto argument = gemm.MakeArgument(

@@ -388,6 +415,7 @@ int run(int argc, char* argv[])
         static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
         static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
         static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
+        static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
         static_cast<DataType*>(ygrad_device_buf.GetDeviceBuffer()),
         static_cast<DataType*>(qgrad_device_buf.GetDeviceBuffer()),
         static_cast<DataType*>(kgrad_device_buf.GetDeviceBuffer()),

@@ -402,6 +430,7 @@ int run(int argc, char* argv[])
         v_gs_os_ns_strides,
         y_gs_ms_os_lengths,
         y_gs_ms_os_strides,
+        lse_gs_ms_lengths,
         {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
         {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
         {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},

@@ -421,6 +450,7 @@ int run(int argc, char* argv[])
     float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

+    // TODO ANT: add dQ/dK/dV flops & bytes
     std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
     std::size_t num_btype =
         (sizeof(DataType) * M * K + sizeof(DataType) * K * N + sizeof(DataType) * N * O +
          sizeof(DataType) * M * O) *

@@ -445,7 +475,7 @@ int run(int argc, char* argv[])
     Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M});
     ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
-        ygrad_g_m_o(idx[0] * G1 * idx[1], idx[3], idx[2]) = self(idx);
+        ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
     });

     if(PRINT_HOST)

@@ -456,44 +486,6 @@ int run(int argc, char* argv[])
         std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
     }

-    // S = alpha * Q * K^T
-    auto k_g_k_n = k_g_n_k.Transpose({0, 2, 1});
-    auto ref_gemm0          = ReferenceGemm0Instance{};
-    auto ref_gemm0_invoker  = ref_gemm0.MakeInvoker();
-    auto ref_gemm0_argument = ref_gemm0.MakeArgument(
-        q_g_m_k, k_g_k_n, s_g_m_n, PassThrough{}, PassThrough{}, Scale{alpha});
-    ref_gemm0_invoker.Run(ref_gemm0_argument);
-
-    // masking
-#if 0
-    const auto mask = DeviceGemmInstance::C0MatrixMask(N);
-    s_g_m_n.ForEach([&](auto& self, auto idx) {
-        if(mask.IsMaskedElement(idx[1], idx[2]))
-            self(idx) = -ck::NumericLimits<float>::Infinity();
-    });
-#endif
-
-    // P = Softmax(S)
-    // >>> scipy.special.softmax(numpy.eye(4), 1)
-    // array([[0.47536689, 0.1748777 , 0.1748777 , 0.1748777 ],
-    //        [0.1748777 , 0.47536689, 0.1748777 , 0.1748777 ],
-    //        [0.1748777 , 0.1748777 , 0.47536689, 0.1748777 ],
-    //        [0.1748777 , 0.1748777 , 0.1748777 , 0.47536689]])
-    auto ref_softmax          = ReferenceSoftmaxInstance{};
-    auto ref_softmax_invoker  = ref_softmax.MakeInvoker();
-    auto ref_softmax_argument = ref_softmax.MakeArgument(s_g_m_n, p_g_m_n, 1, 0, {2});
-    ref_softmax_invoker.Run(ref_softmax_argument);
-
-    // Y = P * V
-    auto ref_gemm1          = ReferenceGemm1Instance{};
-    auto ref_gemm1_invoker  = ref_gemm1.MakeInvoker();
-    auto ref_gemm1_argument = ref_gemm1.MakeArgument(
-        p_g_m_n, v_g_n_o, y_g_m_o, PassThrough{}, PassThrough{}, PassThrough{});
-    ref_gemm1_invoker.Run(ref_gemm1_argument);
-
     // Gradients
     auto ref_gemm_grad         = ReferenceGemmGradInstance{};
     auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker();

@@ -613,7 +605,10 @@ int run(int argc, char* argv[])
                                      kgrad_gs_ns_ks_host_result.mData);
         std::cout << "Checking vgrad:\n";
         pass &= ck::utils::check_err(vgrad_gs_os_ns_device_result.mData,
-                                     vgrad_gs_os_ns_host_result.mData);
+                                     vgrad_gs_os_ns_host_result.mData,
+                                     "error",
+                                     1e-2,
+                                     1e-2);
     }
     return pass ? 0 : 1;
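Note on the last hunk: per the commit message, dV now validates only with a relaxed tolerance, passed to ck::utils::check_err as the extra arguments ("error", 1e-2, 1e-2). As a rough, self-contained sketch of what such a mixed relative/absolute tolerance comparison amounts to (this is an illustration only, not the ck::utils::check_err implementation; the exact meaning of its trailing parameters should be confirmed in check_err.hpp):

// Illustrative sketch: accept out[i] when |out - ref| <= atol + rtol * |ref|.
#include <cmath>
#include <cstdio>
#include <vector>

bool all_close(const std::vector<float>& out, const std::vector<float>& ref,
               float rtol = 1e-2f, float atol = 1e-2f)
{
    if(out.size() != ref.size()) return false;
    for(std::size_t i = 0; i < out.size(); ++i)
    {
        const float err = std::fabs(out[i] - ref[i]);
        if(err > atol + rtol * std::fabs(ref[i]))
        {
            std::printf("mismatch at %zu: %f vs %f\n", i, out[i], ref[i]);
            return false;
        }
    }
    return true;
}

An fp16 dV accumulated over a 256-deep reduction can plausibly exceed a strict default bound, which is presumably why the looser 1e-2 limits are used here.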
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp

@@ -50,7 +50,8 @@ template <index_t BlockSize,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
-          index_t KPack>
+          index_t KPack,
+          bool TransposeC = false>
 struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
 {
     static constexpr auto I0 = Number<0>{};

@@ -72,7 +73,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
     static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);

-    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{};
+    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{};

     static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;

@@ -226,6 +227,21 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
             make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
     }

+    // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
+    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
+    {
+        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
+            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
+                                                           Number<NRepeat>{},
+                                                           Number<MWaves>{},
+                                                           Number<NWaves>{},
+                                                           Number<MPerXDL>{},
+                                                           Number<NPerXDL>{}));
+
+        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
+    }
+
+    // XDL output supporting C_xdl = A_xdl * B_xdl
     __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
     {
         constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp

@@ -108,6 +108,24 @@ struct BlockwiseSoftmax
        });
    }

+    template <typename CThreadBuffer, typename LSEBuffer>
+    __host__ __device__ void RunWithPreCalcStats(CThreadBuffer& in_thread_buf,
+                                                 const LSEBuffer& lse_thread_buf)
+    {
+        // calculate exp for elements using pre-calculated stats LSE (log-sum-exp)
+        // Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
+        //    = exp(Si) / exp(log(sum(exp() + ...)))
+        //    = exp(Si - log(sum(exp() + ...)))
+        static_for<0, MRepeat, 1>{}([&](auto iM) {
+            static_for<0, KRepeat, 1>{}([&](auto iK) {
+                auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
+                in_thread_buf(offset) = IgnoreNaN && ck::math::isnan(in_thread_buf[offset])
+                                            ? 0
+                                            : math::exp(in_thread_buf[offset] - lse_thread_buf[iM]);
+            });
+        });
+    }
+
    BufferType max_value_buf;
    BufferType sum_value_buf;
};
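The RunWithPreCalcStats hunk above re-applies softmax from a pre-computed log-sum-exp, exactly the identity spelled out in its comment: Pi = exp(Si - LSE) with LSE = log(sum_j exp(Sj)), so the backward pass needs no second max/sum reduction over the row. A minimal standalone C++ sketch of that identity (illustration only, not CK code; the max-shift inside log_sum_exp is just the usual numerically safe way to evaluate the same quantity):

// Forward stores lse = log(sum_j exp(s_j)) per row; backward rebuilds p_j = exp(s_j - lse).
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

float log_sum_exp(const std::vector<float>& s)
{
    float m = s[0];
    for(float v : s) m = std::max(m, v);  // max-shift for numerical safety
    float sum = 0.f;
    for(float v : s) sum += std::exp(v - m);
    return m + std::log(sum);             // equals log(sum_j exp(s_j))
}

int main()
{
    std::vector<float> s{0.5f, -1.0f, 2.0f, 0.0f};
    const float lse = log_sum_exp(s);

    float total = 0.f;
    for(float v : s)
    {
        const float p = std::exp(v - lse); // softmax recomputed from the stored LSE
        std::printf("%f ", p);
        total += p;
    }
    std::printf("\nsum = %f (expect 1)\n", total);
    return 0;
}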
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp

@@ -18,12 +18,15 @@
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
+#include "ck/library/utility/host_tensor.hpp"

 namespace ck {
 namespace tensor_operation {
 namespace device {

 template <typename GridwiseGemm,
           typename DataType,
+          typename LSEDataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename AccElementwiseOperation,

@@ -33,6 +36,7 @@ template <typename GridwiseGemm,
           typename BGridDesc_BK0_N_BK1,
           typename B1GridDesc_BK0_N_BK1,
           typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+          typename LSEGridDescriptor_M,
           typename VGradGridDescriptor_N_O,
           typename YGradGridDesc_M0_O_M1,
           typename Block2CTileMap,

@@ -48,6 +52,7 @@ __global__ void
             const DataType* __restrict__ p_b_grid,
             const DataType* __restrict__ p_b1_grid,
             const DataType* __restrict__ p_c_grid,
+            const LSEDataType* __restrict__ p_lse_grid,
             const DataType* __restrict__ p_ygrad_grid,
             DataType* __restrict__ p_qgrad_grid,
             DataType* __restrict__ p_kgrad_grid,

@@ -62,6 +67,7 @@ __global__ void
             const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
             const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
+            const LSEGridDescriptor_M lse_grid_desc_m,
             // const QGradGridDescriptor_M_K qgrad_grid_desc_m_k, // TODO ANT: add dQ/dK args
             // const KGradGridDescriptor_N_K kgrad_grid_desc_n_k,
             const VGradGridDescriptor_N_O vgrad_grid_desc_n_o,

@@ -87,11 +93,14 @@ __global__ void
         static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
     const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
+    const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
+        static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                   p_b_grid + b_batch_offset,
                                                   p_b1_grid + b1_batch_offset,
                                                   p_c_grid + c_batch_offset,
+                                                  p_lse_grid + lse_batch_offset,
                                                   p_ygrad_grid + c_batch_offset,
                                                   p_qgrad_grid + a_batch_offset,
                                                   p_kgrad_grid + b_batch_offset,

@@ -106,6 +115,7 @@ __global__ void
                                                   b_grid_desc_bk0_n_bk1,
                                                   b1_grid_desc_bk0_n_bk1,
                                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                  lse_grid_desc_m,
                                                   vgrad_grid_desc_n_o,
                                                   ygrad_grid_desc_m0_o_m1,
                                                   block_2_ctile_map,

@@ -140,6 +150,7 @@ template <index_t NumDimG,
           index_t NumDimK,
           index_t NumDimO, // NumDimGemm1N
           typename DataType,
+          typename LSEDataType,
           typename Acc0BiasDataType,
           typename Acc1BiasDataType,
           typename GemmAccDataType,

@@ -350,9 +361,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                 v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec)
                 .second;

+        // LogRangeAsType<float>(std::cout << "v_gs_os_ns_lengths_vec: ", v_gs_os_ns_lengths_vec, ",") << std::endl;
+        // LogRangeAsType<float>(std::cout << "v_gs_os_ns_strides_vec: ", v_gs_os_ns_strides_vec, ",") << std::endl;
+        // LogRangeAsType<float>(std::cout << "v_gs_ns_os_lengths_vec: ", v_gs_ns_os_lengths_vec, ",") << std::endl;
+        // LogRangeAsType<float>(std::cout << "v_gs_ns_os_strides_vec: ", v_gs_ns_os_strides_vec, ",") << std::endl;
         return PadTensorDescriptor(vgrad_desc_nraw_oraw,
                                    make_tuple(NPerBlock, Gemm1NPerBlock),
-                                   Sequence<padder.PadM, padder.PadO>{});
+                                   Sequence<padder.PadN, padder.PadO>{});
     }

     template <typename YGridDesc_M_O, typename Number>

@@ -415,10 +430,36 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
         return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths_vec, k_gs_ns_ks_strides_vec);
     }

+    static auto MakeLSEGridDescriptor_M(index_t MRaw)
+    {
+        const auto lse_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
+
+        const auto M    = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
+        const auto MPad = M - MRaw;
+
+        if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+                     GemmSpec == GemmSpecialization::MNPadding ||
+                     GemmSpec == GemmSpecialization::MKPadding ||
+                     GemmSpec == GemmSpecialization::MNKPadding)
+        {
+            // pad M
+            return transform_tensor_descriptor(lse_grid_desc_mraw,
+                                               make_tuple(make_right_pad_transform(MRaw, MPad)),
+                                               make_tuple(Sequence<0>{}),
+                                               make_tuple(Sequence<0>{}));
+        }
+        else
+        {
+            // not pad M
+            return lse_grid_desc_mraw;
+        }
+    }
+
     using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
     using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
     using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
     using CGridDesc_M_N        = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
+    using LSEGridDesc_M        = decltype(MakeLSEGridDescriptor_M(1));
     using AGridDesc_G_M_K      = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
     using BGridDesc_G_N_K      = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
     using B1GridDesc_G_N_K     = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));

@@ -446,11 +487,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
         ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
                                      const BGridDesc_G_N_K& b_grid_desc_g_n_k,
                                      const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
-                                     const CGridDesc_G_M_N& c_grid_desc_g_m_n)
+                                     const CGridDesc_G_M_N& c_grid_desc_g_m_n,
+                                     index_t BatchStrideLSE)
             : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
               b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
               b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
-              c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
+              c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
+              BatchStrideLSE_(BatchStrideLSE)
         {
         }

@@ -474,16 +517,23 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
             return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
         }

+        __host__ __device__ constexpr long_index_t GetLSEBasePtr(index_t g_idx) const
+        {
+            return g_idx * static_cast<long_index_t>(BatchStrideLSE_);
+        }
+
         private:
         AGridDesc_G_M_K a_grid_desc_g_m_k_;
         BGridDesc_G_N_K b_grid_desc_g_n_k_;
         B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
         CGridDesc_G_M_N c_grid_desc_g_m_n_;
+        index_t BatchStrideLSE_;
     };

     // GridwiseGemm
     using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
         DataType, // TODO: distinguish A/B datatype
+        LSEDataType,
         GemmAccDataType,
         CShuffleDataType,
         AElementwiseOperation,

@@ -496,6 +546,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
         BGridDesc_BK0_N_BK1,
         B1GridDesc_BK0_N_BK1,
         CGridDesc_M_N,
+        LSEGridDesc_M,
         NumGemmKPrefetchStage,
         BlockSize,
         MPerBlock,

@@ -552,6 +603,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                  const DataType* p_b_grid,
                  const DataType* p_b1_grid,
                  const DataType* p_c_grid, // for dS
+                 const LSEDataType* p_lse_grid,
                  const DataType* p_ygrad_grid,
                  DataType* p_qgrad_grid,
                  DataType* p_kgrad_grid,

@@ -566,6 +618,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                  const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                  const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
                  const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
+                 const std::vector<index_t>& lse_gs_ms_lengths,
                  const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
                  const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
                  const std::array<std::vector<ck::index_t>, NumAcc1Bias>

@@ -581,6 +634,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
               p_b_grid_{p_b_grid},
               p_b1_grid_{p_b1_grid},
               p_c_grid_{p_c_grid},
+              p_lse_grid_{p_lse_grid},
+              p_ygrad_grid_{p_ygrad_grid},
+              p_qgrad_grid_{p_qgrad_grid},
+              p_kgrad_grid_{p_kgrad_grid},
               p_vgrad_grid_{p_vgrad_grid},
               a_grid_desc_ak0_m_ak1_{
                   DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},

@@ -590,9 +647,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                   b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
               c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
                                                                   c_gs_ms_gemm1ns_strides)},
+              lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])},
               // dV = P^T * dY
-              vgrad_grid_desc_n_o_{DeviceOp::MakeVGradGridDescriptor_N_O(c_gs_ms_gemm1ns_lengths,
-                                                                         c_gs_ms_gemm1ns_strides)},
+              vgrad_grid_desc_n_o_{DeviceOp::MakeVGradGridDescriptor_N_O(b1_gs_gemm1ns_gemm1ks_lengths,
+                                                                         b1_gs_gemm1ns_gemm1ks_strides)},
               /* PTrans descriptor will be constructed in kernel */
               ygrad_grid_desc_m0_o_m1_{
                   DeviceOp::MakeYGradGridDescriptor_M0_O_M1(c_grid_desc_m_n_, Number<Y_M1>{})},

@@ -627,7 +685,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                            c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
               batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
               compute_base_ptr_of_batch_{
-                  a_grid_desc_g_m_k_, b_grid_desc_g_n_k_, b1_grid_desc_g_n_k_, c_grid_desc_g_m_n_}
+                  a_grid_desc_g_m_k_, b_grid_desc_g_n_k_, b1_grid_desc_g_n_k_, c_grid_desc_g_m_n_, type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
         {
             // TODO ANT: implement bias addition
             ignore = p_acc0_biases;

@@ -647,6 +705,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                     GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                         c_grid_desc_m_n_);
             }
+            Print();
         }

         void Print() const

@@ -659,14 +718,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                       << b_grid_desc_g_n_k_.GetLength(I1) << ", "
                       << b_grid_desc_g_n_k_.GetLength(I2) << '\n';
             // b_grid_desc_g_n_k_.Print();
-            std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
+            std::cout << "b1_grid_desc_g_o_n_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
                       << b1_grid_desc_g_n_k_.GetLength(I1) << ", "
                       << b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
             // b1_grid_desc_g_n_k_.Print();
-            std::cout << "c_grid_desc_g_m_n_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", "
+            std::cout << "c_grid_desc_g_m_o_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", "
                       << c_grid_desc_g_m_n_.GetLength(I1) << ", "
                       << c_grid_desc_g_m_n_.GetLength(I2) << '\n';
             // c_grid_desc_g_m_n_.Print();
+            std::cout << "vgrad_grid_desc_n_o_: " << vgrad_grid_desc_n_o_.GetLength(I0) << ", "
+                      << vgrad_grid_desc_n_o_.GetLength(I1) << '\n';
+            std::cout << "ygrad_grid_desc_m0_o_m1_: " << ygrad_grid_desc_m0_o_m1_.GetLength(I0) << ", "
+                      << ygrad_grid_desc_m0_o_m1_.GetLength(I1) << ", " << ygrad_grid_desc_m0_o_m1_.GetLength(I2) << '\n';
         }

         // pointers

@@ -674,6 +737,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
         const DataType* p_b_grid_;
         const DataType* p_b1_grid_;
         const DataType* p_c_grid_;
+        const LSEDataType* p_lse_grid_;
         const DataType* p_ygrad_grid_;
         DataType* p_vgrad_grid_;
         DataType* p_qgrad_grid_;

@@ -684,6 +748,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
         BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
         B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
         CGridDesc_M_N c_grid_desc_m_n_;
+        LSEGridDesc_M lse_grid_desc_m_;
         AGridDesc_G_M_K a_grid_desc_g_m_k_;
         BGridDesc_G_N_K b_grid_desc_g_n_k_;
         B1GridDesc_G_N_K b1_grid_desc_g_n_k_;

@@ -732,7 +797,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
             const index_t grid_size =
                 arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
+            std::cout << "grid size = " << grid_size << '\n';
             // Gemm0_K
             const auto K =
                 arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);

@@ -743,6 +808,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                 const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1<
                     GridwiseGemm,
                     DataType,
+                    LSEDataType,
                     AElementwiseOperation,
                     BElementwiseOperation,
                     AccElementwiseOperation,

@@ -752,6 +818,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                     DeviceOp::BGridDesc_BK0_N_BK1,
                     DeviceOp::B1GridDesc_BK0_N_BK1,
                     typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                    DeviceOp::LSEGridDesc_M,
                     DeviceOp::VGradGridDesc_N_O,
                     DeviceOp::YGradGridDesc_M0_O_M1,
                     typename GridwiseGemm::DefaultBlock2CTileMap,

@@ -768,6 +835,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                                               arg.p_b_grid_,
                                               arg.p_b1_grid_,
                                               arg.p_c_grid_,
+                                              arg.p_lse_grid_,
                                               arg.p_ygrad_grid_,
                                               arg.p_qgrad_grid_,
                                               arg.p_kgrad_grid_,

@@ -781,6 +849,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                                               arg.b_grid_desc_bk0_n_bk1_,
                                               arg.b1_grid_desc_bk0_n_bk1_,
                                               arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                              arg.lse_grid_desc_m_,
                                               arg.vgrad_grid_desc_n_o_,
                                               arg.ygrad_grid_desc_m0_o_m1_,
                                               arg.block_2_ctile_map_,

@@ -791,6 +860,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
             // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
             // to concern Gemm0's loop
+#if 1
             if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
             {
                 ave_time = launch_kernel(integral_constant<bool, true>{});

@@ -799,7 +869,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
             {
                 ave_time = launch_kernel(integral_constant<bool, false>{});
             }
+#endif
             return ave_time;
         }

@@ -898,6 +968,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                              const DataType* p_b,
                              const DataType* p_b1,
                              const DataType* p_c,
+                             const LSEDataType* p_lse,
                              const DataType* p_ygrad_grid,
                              DataType* p_qgrad_grid,
                              DataType* p_kgrad_grid,

@@ -912,6 +983,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                              const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                              const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
                              const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
+                             const std::vector<index_t>& lse_gs_ms_lengths,
                              const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
                              const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
                              const std::array<std::vector<ck::index_t>, NumAcc1Bias>

@@ -928,6 +1000,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                         p_b,
                         p_b1,
                         p_c,
+                        p_lse,
                         p_ygrad_grid,
                         p_qgrad_grid,
                         p_kgrad_grid,

@@ -942,6 +1015,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                         b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                         c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
                         c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
+                        lse_gs_ms_lengths,
                         acc0_biases_gs_ms_ns_lengths,
                         acc0_biases_gs_ms_ns_strides,
                         acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths

@@ -962,6 +1036,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                         const void* p_b,
                         const void* p_b1,
                         const void* p_c,
+                        const void* p_lse,
                         const void* p_ygrad_grid,
                         void* p_qgrad_grid,
                         void* p_kgrad_grid,

@@ -976,6 +1051,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                         const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                         const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
                         const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
+                        const std::vector<index_t>& lse_gs_ms_lengths,
                         const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
                         const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
                         const std::array<std::vector<ck::index_t>, NumAcc1Bias>

@@ -992,6 +1068,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                                           static_cast<const DataType*>(p_b),
                                           static_cast<const DataType*>(p_b1),
                                           static_cast<const DataType*>(p_c),
+                                          static_cast<const LSEDataType*>(p_lse),
                                           static_cast<const DataType*>(p_ygrad_grid),
                                           static_cast<DataType*>(p_qgrad_grid),
                                           static_cast<DataType*>(p_kgrad_grid),

@@ -1006,6 +1083,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                                           b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                                           c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
                                           c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
+                                          lse_gs_ms_lengths,
                                           acc0_biases_gs_ms_ns_lengths,
                                           acc0_biases_gs_ms_ns_strides,
                                           acc1_biases_gs_ms_gemm1ns_lengths,
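The device-op hunks above thread one extra scalar through ComputeBasePtrOfStridedBatch: BatchStrideLSE, set from lse_grid_desc_m_.GetElementSpaceSize(), so that GetLSEBasePtr(g_idx) can return g_idx * BatchStrideLSE. A small standalone sketch of that arithmetic, under the assumption that the LSE element space size is simply M rounded up to a multiple of MPerBlock (illustration only, not CK code):

// Sketch: pad M to the block tile, then offset into a flat per-batch LSE buffer.
#include <cstdio>

int main()
{
    const int MRaw      = 100; // hypothetical raw sequence length
    const int MPerBlock = 64;  // hypothetical block tile size

    // ceil(MRaw / MPerBlock) * MPerBlock, as in MakeLSEGridDescriptor_M
    const int M    = (MRaw + MPerBlock - 1) / MPerBlock * MPerBlock;
    const int MPad = M - MRaw;
    std::printf("MRaw=%d -> padded M=%d (MPad=%d)\n", MRaw, M, MPad);

    // per-batch base offset, as in GetLSEBasePtr: g_idx * BatchStrideLSE
    const int BatchStrideLSE = M;
    for(int g_idx = 0; g_idx < 3; ++g_idx)
        std::printf("batch %d -> lse base offset %d\n", g_idx, g_idx * BatchStrideLSE);

    return 0;
}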
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
View file @
b67a58c0
...
@@ -21,6 +21,7 @@ namespace ck {
...
@@ -21,6 +21,7 @@ namespace ck {
template
<
typename
DataType
,
template
<
typename
DataType
,
typename
FloatGemmAcc
,
typename
FloatGemmAcc
,
typename
FloatCShuffle
,
typename
FloatCShuffle
,
typename
FloatLSE
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
AccElementwiseOperation
,
...
@@ -31,6 +32,7 @@ template <typename DataType,
...
@@ -31,6 +32,7 @@ template <typename DataType,
typename
BGridDesc_BK0_N_BK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
B1GridDesc_BK0_N_BK1
,
typename
B1GridDesc_BK0_N_BK1
,
typename
CGridDesc_M_N
,
typename
CGridDesc_M_N
,
typename
LSEGridDesc_M
,
index_t
NumGemmKPrefetchStage
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
MPerBlock
,
...
@@ -170,7 +172,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -170,7 +172,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
constexpr
index_t
n
=
Free0_N
-
1
;
constexpr
index_t
n
=
Free0_N
-
1
;
constexpr
index_t
n2
=
n
%
NPerXdl
;
constexpr
index_t
n2
=
n
%
NPerXdl
;
constexpr
index_t
n1
=
n
/
NPerXdl
%
Gemm0NWaves
;
constexpr
index_t
n1
=
n
/
NPerXdl
%
Gemm0NWaves
;
constexpr
index_t
n0
=
n
/
NPerXdl
/
Gemm0NWaves
%
M
XdlPerWave
;
constexpr
index_t
n0
=
n
/
NPerXdl
/
Gemm0NWaves
%
N
XdlPerWave
;
// assume 256 decomposed into 2 x 4 x 32
// assume 256 decomposed into 2 x 4 x 32
// 1d idx ( 32 - 1) -> 3d idx 0, 0, 31 -> 3d dim 1 x 1 x 32
// 1d idx ( 32 - 1) -> 3d idx 0, 0, 31 -> 3d dim 1 x 1 x 32
...
@@ -178,6 +180,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -178,6 +180,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
return
Sequence
<
m0
,
n0
,
m1
,
n1
,
m2
,
n2
>
{}
+
Sequence
<
1
,
1
,
1
,
1
,
1
,
1
>
{};
return
Sequence
<
m0
,
n0
,
m1
,
n1
,
m2
,
n2
>
{}
+
Sequence
<
1
,
1
,
1
,
1
,
1
,
1
>
{};
}
}
__host__
__device__
static
constexpr
auto
GetPBlockSliceLengths_M0_N0_M1_N1
()
{
return
generate_sequence_v2
(
[](
auto
I
)
{
return
GetPBlockSliceLengths_M0_N0_M1_N1_M2_N2
().
At
(
I
);
},
Number
<
4
>
{});
}
// template <typename PBlockDesc_M0_N_M1>
// template <typename PBlockDesc_M0_N_M1>
// __host__ __device__ static constexpr auto
// __host__ __device__ static constexpr auto
// MakePMmaTileDescriptor_N0_N1_N2_M(const PBlockDesc_M0_N_M1&)
// MakePMmaTileDescriptor_N0_N1_N2_M(const PBlockDesc_M0_N_M1&)
...
@@ -303,13 +312,21 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -303,13 +312,21 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const
index_t
gemm1_bytes_end
=
const
index_t
gemm1_bytes_end
=
(
SharedMemTrait
::
b1_block_space_offset
+
SharedMemTrait
::
b1_block_space_size_aligned
)
*
(
SharedMemTrait
::
b1_block_space_offset
+
SharedMemTrait
::
b1_block_space_size_aligned
)
*
sizeof
(
DataType
);
sizeof
(
DataType
);
const
index_t
vgrad_gemm_bytes_end
=
(
SharedMemTrait
::
p_block_space_size_aligned
+
SharedMemTrait
::
ygrad_block_space_size_aligned
)
*
sizeof
(
DataType
);
const
index_t
softmax_bytes_end
=
(
SharedMemTrait
::
reduction_space_offset
+
const
index_t
softmax_bytes_end
=
(
SharedMemTrait
::
reduction_space_offset
+
SharedMemTrait
::
reduction_space_size_aligned
)
*
SharedMemTrait
::
reduction_space_size_aligned
)
*
sizeof
(
FloatGemmAcc
);
sizeof
(
FloatGemmAcc
);
const
index_t
c_block_bytes_end
=
const
index_t
c_block_bytes_end
=
SharedMemTrait
::
c_block_space_size
*
sizeof
(
FloatCShuffle
);
SharedMemTrait
::
c_block_space_size
*
sizeof
(
FloatCShuffle
);
return
math
::
max
(
gemm0_bytes_end
,
gemm1_bytes_end
,
softmax_bytes_end
,
c_block_bytes_end
);
return
math
::
max
(
gemm0_bytes_end
,
gemm1_bytes_end
,
vgrad_gemm_bytes_end
,
softmax_bytes_end
,
c_block_bytes_end
);
}
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
@@ -395,6 +412,22 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        return c_grid_desc_mblock_mperblock_nblock_nperblock;
    }

+   __host__ __device__ static constexpr auto
+   MakeLSEGridDescriptor_MBlock_MRepeat_NWave_MPerXdl(const LSEGridDesc_M& lse_grid_desc_m)
+   {
+       const index_t M         = lse_grid_desc_m.GetLength(I0);
+       const index_t MBlock    = M / MPerBlock;
+       constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
+
+       const auto lse_grid_desc_mblock_mrepeat_mwave_mperxdl = transform_tensor_descriptor(
+           lse_grid_desc_m,
+           make_tuple(make_unmerge_transform(
+               make_tuple(MBlock, Number<MXdlPerWave>{}, MWave, Number<MPerXdl>{}))),
+           make_tuple(Sequence<0>{}),
+           make_tuple(Sequence<0, 1, 2, 3>{}));
+
+       return lse_grid_desc_mblock_mrepeat_mwave_mperxdl;
+   }
+
    // return block_id to C matrix tile idx (m0, n0) mapping
    __host__ __device__ static constexpr auto
    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
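The new descriptor splits the flat LSE row index into (MBlock, MXdlPerWave, MWave, MPerXdl) so that each thread can address exactly the log-sum-exp entries matching the accumulator rows it owns. The same decomposition can be checked on the host; the block and wave sizes below are illustrative only, not the kernel's template parameters.

    #include <cstdio>

    int main()
    {
        // Illustrative sizes (the kernel takes these from its template arguments).
        const int MXdlPerWave = 2, MPerXdl = 32;
        const int MPerBlock = 128;
        const int MWave = MPerBlock / (MXdlPerWave * MPerXdl); // waves along M

        const int m = 200; // a global row index into the LSE vector
        // Mirror make_unmerge_transform: last dimension varies fastest.
        const int mperxdl = m % MPerXdl;
        const int mwave   = (m / MPerXdl) % MWave;
        const int mrepeat = (m / (MPerXdl * MWave)) % MXdlPerWave;
        const int mblock  = m / (MPerXdl * MWave * MXdlPerWave); // == m / MPerBlock

        std::printf("m=%d -> (mblock %d, mrepeat %d, mwave %d, mperxdl %d)\n",
                    m, mblock, mrepeat, mwave, mperxdl);
    }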
@@ -418,8 +451,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
        static constexpr auto b1_block_desc_bk0_n_bk1 =
            GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
+       static constexpr auto p_block_desc_m0_n_m1 =
+           VGradGemmTile_N_O_M::GetPBlockDescriptor_M0_N_M1();
+       static constexpr auto ygrad_block_desc_m0_o_m1 =
+           VGradGemmTile_N_O_M::GetYGradBlockDescriptor_M0_O_M1();

-       static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1);
+       static constexpr auto max_lds_align = Number<16 / sizeof(DataType)>{};

        static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
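The alignment constant switches from the lcm of the K1 vector lengths to a fixed Number<16 / sizeof(DataType)>{}, i.e. 16-byte (one dwordx4) alignment expressed in elements. A small sketch of the element counts this produces for two representative element widths; the stand-in types below are only for illustration.

    #include <cstdint>
    #include <cstdio>

    // 16-byte alignment expressed as a number of elements of type T,
    // mirroring Number<16 / sizeof(DataType)>{} in the kernel.
    template <typename T>
    constexpr int lds_align_elements() { return 16 / sizeof(T); }

    int main()
    {
        // uint16_t stands in for fp16 here; float for the fp32 accumulator type.
        std::printf("2-byte elements align to %d elements\n", lds_align_elements<uint16_t>()); // 8
        std::printf("4-byte elements align to %d elements\n", lds_align_elements<float>());    // 4
    }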
@@ -427,10 +464,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
        static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
            b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
+       static constexpr auto p_block_space_size_aligned = math::integer_least_multiple(
+           p_block_desc_m0_n_m1.GetElementSpaceSize(), max_lds_align);
+       static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
+           ygrad_block_desc_m0_o_m1.GetElementSpaceSize(), max_lds_align);

        static constexpr auto a_block_space_offset  = 0;
        static constexpr auto b_block_space_offset  = a_block_space_size_aligned.value;
        static constexpr auto b1_block_space_offset = 0;
+       static constexpr auto p_block_space_offset     = 0;
+       static constexpr auto ygrad_block_space_offset = p_block_space_size_aligned.value;

        // LDS allocation for reduction
        static constexpr index_t reduction_space_size_aligned =
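p_block_space_offset starts at 0 and ygrad_block_space_offset follows right after the aligned P tile, so the dV-gemm operands overlay the LDS region used by earlier stages; that is safe only because the stages never hold live data at the same time. A sketch of the offset arithmetic with placeholder sizes (integer_least_multiple is re-implemented here only for illustration):

    #include <cstdio>

    // Round x up to a multiple of align, like math::integer_least_multiple.
    constexpr int integer_least_multiple(int x, int align)
    {
        return (x + align - 1) / align * align;
    }

    int main()
    {
        const int align = 8;                                  // elements (16 B of fp16)
        const int p_elems = 130 * 33, ygrad_elems = 70 * 65;  // placeholder tile sizes

        const int p_space      = integer_least_multiple(p_elems, align);
        const int p_offset     = 0;                           // overlays the earlier stages' region
        const int ygrad_space  = integer_least_multiple(ygrad_elems, align);
        const int ygrad_offset = p_offset + p_space;          // dY tile packed right after P

        std::printf("P at %d (%d elems), dY at %d (%d elems)\n",
                    p_offset, p_space, ygrad_offset, ygrad_space);
    }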
@@ -454,6 +497,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        const DataType* __restrict__ p_b_grid,
        const DataType* __restrict__ p_b1_grid,
        const DataType* __restrict__ p_c_grid,
+       const FloatLSE* __restrict__ p_lse_grid,
        const DataType* __restrict__ p_ygrad_grid,
        DataType* __restrict__ p_qgrad_grid,
        DataType* __restrict__ p_kgrad_grid,
@@ -469,6 +513,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
        const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
            c_grid_desc_mblock_mperblock_nblock_nperblock,
+       const LSEGridDesc_M& lse_grid_desc_m,
        const VGradGridDescriptor_N_O& vgrad_grid_desc_n_o,
        const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1,
        const Block2CTileMap& block_2_ctile_map,
@@ -482,6 +527,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        const auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
+       auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+           p_lse_grid, lse_grid_desc_m.GetElementSpaceSize());
+       const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+           p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize());
+       auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+           p_vgrad_grid, vgrad_grid_desc_n_o.GetElementSpaceSize());

        // divide block work by [M, N]
        const auto block_work_idx =
@@ -830,6 +881,35 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        running_max     = NumericLimits<FloatGemmAcc>::Lowest();
        running_max_new = NumericLimits<FloatGemmAcc>::Lowest();

+       auto lse_grid_desc_mblock_mrepeat_mwave_mperxdl =
+           MakeLSEGridDescriptor_MBlock_MRepeat_NWave_MPerXdl(lse_grid_desc_m);
+
+       constexpr auto lse_thread_desc_mblock_mrepeat_mwave_mperxdl =
+           make_naive_tensor_descriptor_packed(make_tuple(I1, m0, m1, m2));
+
+       auto lse_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatLSE>(
+           lse_thread_desc_mblock_mrepeat_mwave_mperxdl.GetElementSpaceSize());
+
+       auto acc0_thread_origin = blockwise_gemm.CalculateCThreadOriginDataIndex8D(
+           Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{});
+
+       auto lse_thread_copy_global_to_vgpr =
+           ThreadwiseTensorSliceTransfer_v2<FloatLSE,
+                                            FloatLSE,
+                                            decltype(lse_grid_desc_mblock_mrepeat_mwave_mperxdl),
+                                            decltype(lse_thread_desc_mblock_mrepeat_mwave_mperxdl),
+                                            Sequence<1, m0, m1, m2>,
+                                            Sequence<0, 1, 2, 3>,
+                                            3,
+                                            m2,
+                                            1,
+                                            false>{
+               lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
+               make_multi_index(block_work_idx[I0],       // mblock
+                                acc0_thread_origin[I0],   // mrepeat
+                                acc0_thread_origin[I2],   // mwave
+                                acc0_thread_origin[I4])}; // mperxdl
+
        //
        // dV
        //
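Each thread pulls the LSE values for its rows into VGPRs once; the softmax in the backward pass then becomes a straight recomputation p = exp(s - lse), with no running max/sum update. A host-side sketch of that identity using plain C++ (not the blockwise_softmax helper):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main()
    {
        // One attention row of raw scores s = scale * q.k (illustrative values).
        std::vector<float> s = {0.5f, 2.0f, -1.0f, 0.25f};

        // The forward pass stores lse = max + log(sum(exp(s - max))) per row.
        float mx = s[0];
        for(float v : s) mx = std::max(mx, v);
        float sum = 0.f;
        for(float v : s) sum += std::exp(v - mx);
        const float lse = mx + std::log(sum);

        // The backward pass recomputes the probabilities directly from the stored stat.
        for(float v : s) std::printf("%f ", std::exp(v - lse)); // the row sums to 1
        std::printf("\n");
    }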
@@ -837,11 +917,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        // P vgpr to lds: writes vgprs of a subgroup to LDS and transform into m0_n_m1
        // m0, n0 are m/n repeat per wave
        // m1, n1 are number of waves
-       constexpr auto p_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
+       constexpr auto p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
            blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
-       constexpr auto p_dst_block_desc_m0_n_m1 =
+       constexpr auto p_block_desc_m0_n_m1 =
            VGradGemmTile_N_O_M::GetPBlockDescriptor_M0_N_M1();

        constexpr auto p_block_lengths =
            blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
@@ -854,30 +933,28 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        constexpr auto P_N3 = p_block_lengths[I6];
        constexpr auto P_N4 = p_block_lengths[I7];

-       constexpr auto p_dst_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = [&]() constexpr
+       constexpr auto p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = [&]() constexpr
        {
-           constexpr auto p_dst_block_desc_m_n = transform_tensor_descriptor(
-               p_dst_block_desc_m0_n_m1,
+           constexpr auto p_block_desc_m_n = transform_tensor_descriptor(
+               p_block_desc_m0_n_m1,
                make_tuple(make_merge_transform_v3_division_mod(
                               make_tuple(VGradGemmTile_N_O_M::P_M0, VGradGemmTile_N_O_M::P_M1)),
                           make_pass_through_transform(VGradGemmTile_N_O_M::Free0_N)),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

+           // HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
+           // variable I1 there
            return transform_tensor_descriptor(
-               p_dst_block_desc_m_n,
-               make_tuple(make_unmerge_transform(make_tuple(P_M0, P_M1, P_M2)),
-                          make_unmerge_transform(make_tuple(P_N0, P_N1, P_N2, P_N3, P_N4))),
+               p_block_desc_m_n,
+               make_tuple(make_unmerge_transform(make_tuple(I1, P_M1, P_M2)),
+                          make_unmerge_transform(make_tuple(I1, P_N1, P_N2, P_N3, P_N4))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
        }();

-       // TODO ANT: check lds offset
-       auto p_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<DataType*>(p_shared), p_dst_block_desc_m0_n_m1.GetElementSpaceSize());
-
-       const auto p_dst_thread_origin = [&]() {
+       const auto p_thread_origin_nd_idx_on_block = [&]() {
            const auto c_thread_mtx_on_block =
                blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
@@ -904,8 +981,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

-           return make_tuple(0,                                // mrepeat
-                             0,                                // nrepeat
+           return make_tuple(m_thread_data_on_block_idx[I0],   // mrepeat
+                             n_thread_data_on_block_idx[I0],   // nrepeat
                              m_thread_data_on_block_idx[I1],   // mwave
                              n_thread_data_on_block_idx[I1],   // nwave
                              m_thread_data_on_block_idx[I2],   // xdlops
@@ -914,19 +991,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                              n_thread_data_on_block_idx[I4]);
        }();

-       constexpr auto p_block_slice_lengths_m0_n0_m1_n1_m2_n2 =           // mrepeat, nrepeat, mwaves,
-           VGradGemmTile_N_O_M::GetPBlockSliceLengths_M0_N0_M1_N1_M2_N2(); // nwaves, mperxdl, nperxdl
+       constexpr auto p_block_slice_lengths_m0_n0_m1_n1 =                  // mrepeat, nrepeat,
+           VGradGemmTile_N_O_M::GetPBlockSliceLengths_M0_N0_M1_N1();       // mwaves, nwaves

        // how to properly perform copy for a sub-workgroup?
        auto p_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
            FloatGemmAcc,
            DataType,
-           decltype(p_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
-           decltype(p_dst_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
+           decltype(p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
+           decltype(p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
            tensor_operation::element_wise::PassThrough,
-           Sequence<p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I0], // ThreadSliceLengths
-                    p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I1],
+           Sequence<p_block_slice_lengths_m0_n0_m1_n1[I0],       // ThreadSliceLengths
+                    p_block_slice_lengths_m0_n0_m1_n1[I1],
                     I1,
                     I1,
                     I1,
@@ -939,21 +1016,35 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            InMemoryDataOperationEnum::Set,
            1, // DstScalarStrideInVector
            true>{
-           p_dst_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-           make_multi_index(p_dst_thread_origin[I0],
-                            p_dst_thread_origin[I1],
-                            p_dst_thread_origin[I2] % p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I2],
-                            p_dst_thread_origin[I3] % p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I3],
-                            p_dst_thread_origin[I4],
-                            p_dst_thread_origin[I5],
-                            p_dst_thread_origin[I6],
-                            p_dst_thread_origin[I7]),
+           p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+           make_multi_index(p_thread_origin_nd_idx_on_block[I0],
+                            p_thread_origin_nd_idx_on_block[I1],
+                            p_thread_origin_nd_idx_on_block[I2] % p_block_slice_lengths_m0_n0_m1_n1[I2],
+                            p_thread_origin_nd_idx_on_block[I3] % p_block_slice_lengths_m0_n0_m1_n1[I3],
+                            p_thread_origin_nd_idx_on_block[I4],
+                            p_thread_origin_nd_idx_on_block[I5],
+                            p_thread_origin_nd_idx_on_block[I6],
+                            p_thread_origin_nd_idx_on_block[I7]),
            tensor_operation::element_wise::PassThrough{}};

-       // p_thread_copy_vgpr_to_lds.Run();
+       // construct space filling curve
+       // Sequence<p_block_slice_lengths_m0_n0_m1_n1[I0],
+       //          p_block_slice_lengths_m0_n0_m1_n1[I1],
+       //          I1,
+       //          I1,
+       //          I1,
+       //          P_N2,
+       //          I1,
+       //          P_N4>{}
+       //     .foo();
+       // 1, 4, 1, 1, 1, 4, 1, 4
+       constexpr auto sfc_p_m0_n0_m1_n1_m2_n2 =
+           SpaceFillingCurve<Sequence<P_M0, P_N0, P_M1, P_N1>,
+                             Sequence<0, 1, 2, 3>,
+                             decltype(p_block_slice_lengths_m0_n0_m1_n1)>{};

-       constexpr auto ygrad_dst_block_desc_m0_o_m1 =
+       constexpr auto ygrad_block_desc_m0_o_m1 =
            VGradGemmTile_N_O_M::GetYGradBlockDescriptor_M0_O_M1();

        auto ygrad_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
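The space-filling curve enumerates the (mrepeat, nrepeat, mwave, nwave) sub-tiles of P in a fixed order, so each dV inner iteration knows which slice of the accumulator to stage into LDS; GetIndexTupleOfNumber maps the loop counter back to that 4-D index. Below is a simplified host-side enumeration in plain row-major order — the real SpaceFillingCurve may snake through the dimensions, this only illustrates the counter-to-index mapping, and the lengths are placeholders.

    #include <array>
    #include <cstdio>

    int main()
    {
        // Tile lengths and per-access slice lengths, 4 dims like
        // Sequence<P_M0, P_N0, P_M1, P_N1> and GetPBlockSliceLengths_M0_N0_M1_N1().
        const std::array<int, 4> lengths = {2, 4, 2, 2};
        const std::array<int, 4> slice   = {1, 4, 1, 1};

        std::array<int, 4> steps{};
        int num_access = 1;
        for(int d = 0; d < 4; ++d) { steps[d] = lengths[d] / slice[d]; num_access *= steps[d]; }

        for(int i = 0; i < num_access; ++i) // one access per dV inner-loop iteration
        {
            std::array<int, 4> idx{};
            int rem = i;
            for(int d = 3; d >= 0; --d) { idx[d] = rem % steps[d] * slice[d]; rem /= steps[d]; }
            std::printf("access %d -> (%d, %d, %d, %d)\n", i, idx[0], idx[1], idx[2], idx[3]);
        }
    }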
@@ -967,7 +1058,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            DataType,
            DataType,
            decltype(ygrad_grid_desc_m0_o_m1),
-           decltype(ygrad_dst_block_desc_m0_o_m1),
+           decltype(ygrad_block_desc_m0_o_m1),
            typename VGradGemmTile_N_O_M::YGrad_ThreadClusterArrangeOrder, // access order == thread
                                                                           // order
            Sequence<1, 0, 2>,
@@ -980,113 +1071,165 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -980,113 +1071,165 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
true
,
true
,
true
,
true
,
1
>
(
ygrad_grid_desc_m0_o_m1
,
1
>
(
ygrad_grid_desc_m0_o_m1
,
make_multi_index
(
0
,
gemm1_n_block_data_idx_on_grid
,
0
),
make_multi_index
(
m_block_data_idx_on_grid
/
VGradGemmTile_N_O_M
::
YGrad_M1
,
gemm1_n_block_data_idx_on_grid
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{},
tensor_operation
::
element_wise
::
PassThrough
{},
ygrad_
dst_
block_desc_m0_o_m1
,
ygrad_block_desc_m0_o_m1
,
make_multi_index
(
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
tensor_operation
::
element_wise
::
PassThrough
{});
auto
p_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
p_block_space_offset
,
p_block_desc_m0_n_m1
.
GetElementSpaceSize
());
auto
ygrad_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
ygrad_block_space_offset
,
ygrad_block_desc_m0_o_m1
.
GetElementSpaceSize
());
auto
vgrad_blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
auto
vgrad_blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
BlockSize
,
DataType
,
DataType
,
FloatGemmAcc
,
FloatGemmAcc
,
decltype
(
p_
dst_
block_desc_m0_n_m1
),
decltype
(
p_block_desc_m0_n_m1
),
decltype
(
ygrad_
dst_
block_desc_m0_o_m1
),
decltype
(
ygrad_block_desc_m0_o_m1
),
MPerXdl
,
MPerXdl
,
NPerXdl
,
NPerXdl
,
VGradGemmTile_N_O_M
::
GemmNRepeat
,
// NRepeat
VGradGemmTile_N_O_M
::
GemmNRepeat
,
VGradGemmTile_N_O_M
::
GemmORepeat
,
// ORepeat
VGradGemmTile_N_O_M
::
GemmORepeat
,
VGradGemmTile_N_O_M
::
GemmMPack
>
{};
VGradGemmTile_N_O_M
::
GemmMPack
,
true
>
{};
// TranspossC
constexpr
auto
vgrad_block_lengths
=
auto
vgrad_acc_thread_buf
=
vgrad_blockwise_gemm
.
GetCThreadBuffer
();
vgrad_blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
().
GetLengths
();
// HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
// variable I1 there
const
auto
vgrad_grid_desc_n0_o0_n1_o1_n2_o2
=
transform_tensor_descriptor
(
const
auto
vgrad_grid_desc_n0_o0_n1_o1_n2_o2
=
transform_tensor_descriptor
(
vgrad_grid_desc_n_o
,
vgrad_grid_desc_n_o
,
make_tuple
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
I1
,
// may place a dummy variable
p_block_slice_lengths_m0_n0_m1_n1_m2_n2
[
I2
],
p_block_slice_lengths_m0_n0_m1_n1_m2_n2
[
I4
])),
make_unmerge_transform
(
make_tuple
(
I1
,
make_unmerge_transform
(
make_tuple
(
I1
,
p_block_slice_lengths_m0_n0_m1_n1_m2_n2
[
I3
],
VGradGemmTile_N_O_M
::
GemmNWave
,
p_block_slice_lengths_m0_n0_m1_n1_m2_n2
[
I5
]))),
MPerXdl
)),
make_unmerge_transform
(
make_tuple
(
I1
,
VGradGemmTile_N_O_M
::
GemmOWave
,
NPerXdl
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}));
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}));
constexpr
auto
vgrad_thread_desc_n0_o0_n1_o1_n2_
n3_n4
_o
2
=
constexpr
auto
vgrad_thread_desc_n0_o0_n1_o1_n2_
o2_o3
_o
4
=
vgrad_blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_
M3_M4
_N
2
();
vgrad_blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_
N2_N3
_N
4
();
const
auto
vgrad_grid_desc_n0_o0_n1_o1_n2_
n3_n4
_o
2
=
const
auto
vgrad_grid_desc_n0_o0_n1_o1_n2_
o2_o3
_o
4
=
vgrad_blockwise_gemm
.
xdlops_gemm
.
MakeCDescriptor_M0_N0_M1_N1_M2_
M3_M4
_N
2
(
vgrad_blockwise_gemm
.
xdlops_gemm
.
MakeCDescriptor_M0_N0_M1_N1_M2_
N2_N3
_N
4
(
vgrad_grid_desc_n0_o0_n1_o1_n2_o2
);
vgrad_grid_desc_n0_o0_n1_o1_n2_o2
);
const
auto
vgrad_thread_mtx_on_block_n_o
=
const
auto
vgrad_thread_mtx_on_block_n_o
=
vgrad_blockwise_gemm
.
CalculateCThreadOriginDataIndex
(
I0
,
I0
,
I0
,
I0
);
vgrad_blockwise_gemm
.
CalculateCThreadOriginDataIndex
(
I0
,
I0
,
I0
,
I0
);
constexpr
auto
vgrad_block_desc_n0_o0_n1_o1_n2_
n3_n4
_o
2
=
constexpr
auto
vgrad_block_desc_n0_o0_n1_o1_n2_
o2_o3
_o
4
=
decltype
(
vgrad_blockwise_gemm
)
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_
M3_M4
_N
2
();
decltype
(
vgrad_blockwise_gemm
)
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_
N2_N3
_N
4
();
constexpr
auto
VGrad_N0
=
vgrad_block_desc_n0_o0_n1_o1_n2_
n3_n4
_o
2
.
GetLength
(
I0
);
constexpr
auto
VGrad_N0
=
vgrad_block_desc_n0_o0_n1_o1_n2_
o2_o3
_o
4
.
GetLength
(
I0
);
constexpr
auto
VGrad_O0
=
vgrad_block_desc_n0_o0_n1_o1_n2_
n3_n4
_o
2
.
GetLength
(
I1
);
constexpr
auto
VGrad_O0
=
vgrad_block_desc_n0_o0_n1_o1_n2_
o2_o3
_o
4
.
GetLength
(
I1
);
constexpr
auto
VGrad_N1
=
vgrad_block_desc_n0_o0_n1_o1_n2_
n3_n4
_o
2
.
GetLength
(
I2
);
constexpr
auto
VGrad_N1
=
vgrad_block_desc_n0_o0_n1_o1_n2_
o2_o3
_o
4
.
GetLength
(
I2
);
constexpr
auto
VGrad_O1
=
vgrad_block_desc_n0_o0_n1_o1_n2_
n3_n4
_o
2
.
GetLength
(
I3
);
constexpr
auto
VGrad_O1
=
vgrad_block_desc_n0_o0_n1_o1_n2_
o2_o3
_o
4
.
GetLength
(
I3
);
constexpr
auto
VGrad_N2
=
vgrad_block_desc_n0_o0_n1_o1_n2_
n3_n4
_o
2
.
GetLength
(
I4
);
constexpr
auto
VGrad_N2
=
vgrad_block_desc_n0_o0_n1_o1_n2_
o2_o3
_o
4
.
GetLength
(
I4
);
constexpr
auto
VGrad_
N3
=
vgrad_block_desc_n0_o0_n1_o1_n2_
n3_n4
_o
2
.
GetLength
(
I5
);
constexpr
auto
VGrad_
O2
=
vgrad_block_desc_n0_o0_n1_o1_n2_
o2_o3
_o
4
.
GetLength
(
I5
);
constexpr
auto
VGrad_
N4
=
vgrad_block_desc_n0_o0_n1_o1_n2_
n3_n4
_o
2
.
GetLength
(
I6
);
constexpr
auto
VGrad_
O3
=
vgrad_block_desc_n0_o0_n1_o1_n2_
o2_o3
_o
4
.
GetLength
(
I6
);
constexpr
auto
VGrad_O
2
=
vgrad_block_desc_n0_o0_n1_o1_n2_
n3_n4
_o
2
.
GetLength
(
I7
);
constexpr
auto
VGrad_O
4
=
vgrad_block_desc_n0_o0_n1_o1_n2_
o2_o3
_o
4
.
GetLength
(
I7
);
const
index_t
n_thread_data_idx_on_grid
=
const
index_t
n_thread_data_idx_on_grid
=
vgrad_thread_mtx_on_block_n_o
[
I0
];
vgrad_thread_mtx_on_block_n_o
[
I0
];
// TODO ANT: step n after each Gemm1 outer loop
const
index_t
o_thread_data_idx_on_grid
=
const
index_t
o_thread_data_idx_on_grid
=
vgrad_thread_mtx_on_block_n_o
[
I1
]
+
gemm1_n_block_data_idx_on_grid
;
vgrad_thread_mtx_on_block_n_o
[
I1
]
+
gemm1_n_block_data_idx_on_grid
;
const
auto
n_thread_data_on_grid_to_n0_n1_n2_
n3_n4_
adaptor
=
const
auto
n_thread_data_on_grid_to_n0_n1_n2_adaptor
=
make_single_stage_tensor_adaptor
(
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
make_merge_transform
(
make_tuple
(
VGrad_N0
,
VGrad_N1
,
VGrad_N2
,
VGrad_N3
,
VGrad_N4
))),
make_tuple
(
VGrad_N0
,
VGrad_N1
,
VGrad_N2
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
make_tuple
(
Sequence
<
0
>
{}));
const
auto
n_thread_data_nd_idx_on_grid
=
const
auto
n_thread_data_nd_idx_on_grid
=
n_thread_data_on_grid_to_n0_n1_n2_
n3_n4_
adaptor
.
CalculateBottomIndex
(
n_thread_data_on_grid_to_n0_n1_n2_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
n_thread_data_idx_on_grid
));
make_multi_index
(
n_thread_data_idx_on_grid
));
const
auto
o_thread_data_on_grid_to_o0_o1_o2_adaptor
=
make_single_stage_tensor_adaptor
(
const
auto
o_thread_data_on_grid_to_o0_o1_o2_o3_o4_adaptor
=
make_tuple
(
make_merge_transform
(
make_tuple
(
VGrad_O0
,
VGrad_O1
,
VGrad_O2
))),
make_single_stage_tensor_adaptor
(
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
make_merge_transform
(
make_tuple
(
Sequence
<
0
>
{}));
make_tuple
(
VGrad_O0
,
VGrad_O1
,
VGrad_O2
,
VGrad_O3
,
VGrad_O4
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
o_thread_data_nd_idx_on_grid
=
const
auto
o_thread_data_nd_idx_on_grid
=
o_thread_data_on_grid_to_o0_o1_o2_adaptor
.
CalculateBottomIndex
(
o_thread_data_on_grid_to_o0_o1_o2_
o3_o4_
adaptor
.
CalculateBottomIndex
(
make_multi_index
(
o_thread_data_idx_on_grid
));
make_multi_index
(
o_thread_data_idx_on_grid
));
auto
vgrad_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
auto
vgrad_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatGemmAcc
,
FloatGemmAcc
,
DataType
,
DataType
,
decltype
(
vgrad_thread_desc_n0_o0_n1_o1_n2_
n3_n4
_o
2
),
decltype
(
vgrad_thread_desc_n0_o0_n1_o1_n2_
o2_o3
_o
4
),
decltype
(
vgrad_grid_desc_n0_o0_n1_o1_n2_
n3_n4
_o
2
),
decltype
(
vgrad_grid_desc_n0_o0_n1_o1_n2_
o2_o3
_o
4
),
tensor_operation
::
element_wise
::
PassThrough
,
// CElementwiseOperation
tensor_operation
::
element_wise
::
PassThrough
,
// CElementwiseOperation
decltype
(
vgrad_blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
()
decltype
(
vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4
.
GetLengths
()),
// SliceLengths
.
GetLengths
()),
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
// AccessOrder
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
// AccessOrder
7
,
// VectorDim
7
,
// VectorDim
1
,
// ScalarPerVector
2
,
// ScalarPerVector
InMemoryDataOperationEnum
::
AtomicAdd
,
// GlobalMemoryDataOperation
InMemoryDataOperationEnum
::
AtomicAdd
,
// GlobalMemoryDataOperation
1
,
1
,
// DstScalarStrideInVector
true
>
(
vgrad_grid_desc_n0_o0_n1_o1_n2_
n3_n4
_o
2
,
true
>
(
vgrad_grid_desc_n0_o0_n1_o1_n2_
o2_o3
_o
4
,
make_multi_index
(
n_thread_data_nd_idx_on_grid
[
I0
],
make_multi_index
(
n_thread_data_nd_idx_on_grid
[
I0
],
o_thread_data_nd_idx_on_grid
[
I0
],
o_thread_data_nd_idx_on_grid
[
I0
],
n_thread_data_nd_idx_on_grid
[
I1
],
n_thread_data_nd_idx_on_grid
[
I1
],
o_thread_data_nd_idx_on_grid
[
I1
],
o_thread_data_nd_idx_on_grid
[
I1
],
n_thread_data_nd_idx_on_grid
[
I2
],
n_thread_data_nd_idx_on_grid
[
I2
],
n
_thread_data_nd_idx_on_grid
[
I
3
],
o
_thread_data_nd_idx_on_grid
[
I
2
],
n
_thread_data_nd_idx_on_grid
[
I
4
],
o
_thread_data_nd_idx_on_grid
[
I
3
],
o_thread_data_nd_idx_on_grid
[
I
2
]),
o_thread_data_nd_idx_on_grid
[
I
4
]),
tensor_operation
::
element_wise
::
PassThrough
{});
tensor_operation
::
element_wise
::
PassThrough
{});
// TODO ANT: ygrad slice window step size
#if 0
#if 0
if(hipThreadIdx_x % 32 < 4)
{
printf("wid %zd tid %zd _n0_o0_n1_o1_n2_o2_o3_o4 %d %d %d %d %d %d %d %d\n",
hipBlockIdx_x,
hipThreadIdx_x,
n_thread_data_nd_idx_on_grid[I0],
o_thread_data_nd_idx_on_grid[I0],
n_thread_data_nd_idx_on_grid[I1],
o_thread_data_nd_idx_on_grid[I1],
n_thread_data_nd_idx_on_grid[I2],
o_thread_data_nd_idx_on_grid[I2],
o_thread_data_nd_idx_on_grid[I3],
o_thread_data_nd_idx_on_grid[I4]);
}
#endif
// p_thread_slice_copy_step will be in for loop
constexpr
auto
ygrad_block_slice_copy_step
=
make_multi_index
(
VGradGemmTile_N_O_M
::
YGrad_M0
,
0
,
0
);
constexpr
auto
ygrad_block_reset_copy_step
=
make_multi_index
(
-
MPerBlock
/
VGradGemmTile_N_O_M
::
YGrad_M1
,
0
,
0
);
// vgrad gemm output tile
const
auto
vgrad_block_slice_copy_step
=
make_multi_index
(
VGradGemmTile_N_O_M
::
GemmNRepeat
,
0
,
0
,
0
,
0
,
0
,
0
,
0
);
#if 0
if(hipThreadIdx_x == 0)
{
printf("bid %zd, n_grid = %d, o_grid = %d, step N0 = %d\n",
hipBlockIdx_x,
n_thread_data_idx_on_grid,
o_thread_data_idx_on_grid,
n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(NPerBlock))[I0]);
}
#endif
constexpr
index_t
num_vgrad_gemm_loop
=
MPerBlock
/
VGradGemmTile_N_O_M
::
Sum_M
;
lse_thread_copy_global_to_vgpr
.
Run
(
lse_grid_desc_mblock_mrepeat_mwave_mperxdl
,
lse_grid_buf
,
lse_thread_desc_mblock_mrepeat_mwave_mperxdl
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
lse_thread_buf
);
// gemm1 K loop
// gemm1 K loop
index_t
gemm1_k_block_outer_index
=
0
;
index_t
gemm1_k_block_outer_index
=
0
;
do
do
@@ -1178,131 +1321,123 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            block_sync_lds(); // wait for lds read in gemm0 blockwise gemm

-           // softmax
-           SoftmaxBuf& max = blockwise_softmax.max_value_buf;
-           SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
-
-           blockwise_softmax.Run(acc_thread_buf, workspace_buf);
-
-           // TODO: may convert to log domain
-           running_max_new = mathext::max(max, running_max);
-           running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
-                             mathext::exp(max - running_max_new) * sum;
-
-           // gemm1
-           {
-               // TODO: explore using dynamic buffer for a1 thread buffer
-               // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
-               // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
-               // the A1 source buffer is static buffer holding the output of first GEMM and
-               // requires constexpr offset by design. Therefore, we pass tensor coordinate offset
-               // explicitly in Run() below.
-
-               // Initialize acc1
-               acc1_thread_buf.Clear();
-
-               // preload data into LDS
-               b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
-
-               b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
-                                                    b1_block_slice_copy_step);
-
-               block_sync_lds(); // wait for reduction LDS read
-
-               b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
-
-               // main body
-               if constexpr(num_gemm1_k_block_inner_loop > 1)
-               {
-                   static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
-                       a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
-                                             make_tuple(Number<i * A1ThreadSliceK0>{}, I0, I0),
-                                             acc_thread_buf,
-                                             a1_thread_desc_k0_m_k1,
-                                             make_tuple(I0, I0, I0),
-                                             a1_thread_buf);
-
-                       b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
-
-                       block_sync_lds();
-
-                       gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
-
-                       block_sync_lds();
-
-                       b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
-                                                            b1_block_slice_copy_step);
-
-                       b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
-                   });
-               }
-
-               // tail
-               {
-                   a1_blockwise_copy.Run(
-                       acc_thread_desc_k0_m_k1,
-                       make_tuple(
-                           Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0),
-                       acc_thread_buf,
-                       a1_thread_desc_k0_m_k1,
-                       make_tuple(I0, I0, I0),
-                       a1_thread_buf);
-
-                   block_sync_lds();
-
-                   gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
-               }
-           } // end gemm1
-
-           // workaround compiler issue; see ck/ck.hpp
-           if constexpr(CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE == 1 &&
-                        is_same_v<FloatAB, bhalf_t> && MPerBlock == 256 && NPerBlock == 128 &&
-                        Gemm1NPerBlock == 128)
-           {
-               __builtin_amdgcn_sched_barrier(0);
-           }
-
-           constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
-               gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
-           constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
-           constexpr auto cn0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
-           constexpr auto cm1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
-           constexpr auto cn1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
-           constexpr auto cm2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
-           constexpr auto cn2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
-           constexpr auto cn3 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
-           constexpr auto cn4 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
-           constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
-               make_tuple(cm0 * cm1 * cm2, cn0 * cn1 * cn2 * cn3 * cn4));
-           constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
-           constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
-
-           static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
-               static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
-                   auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
-                   FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V
-                   FloatGemmAcc c    = c_thread_buf[I];    // O
-                   FloatGemmAcc c_new =
-                       (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
-                        math::exp(max[iM] - running_max_new[iM]) * acc1) /
-                       running_sum_new[iM]; // Formula by Dao et al.,
-                                            // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
-
-                   c_thread_buf(I) = c_new; // O_new
-               });
-           });
+#if 0
+           if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
+           {
+               printf("tid %zd, S[0:3] = %f, %f, %f, %f\n",
+                      hipThreadIdx_x,
+                      acc_thread_buf[I0],
+                      acc_thread_buf[I1],
+                      acc_thread_buf[I2],
+                      acc_thread_buf[I3]);
+           }
+#endif
+
+           // softmax
+           blockwise_softmax.RunWithPreCalcStats(acc_thread_buf, lse_thread_buf);
+
+#if 0
+           if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
+           {
+               printf("tid %zd, P[0:3] = %f, %f, %f, %f\n",
+                      hipThreadIdx_x,
+                      acc_thread_buf[I0],
+                      acc_thread_buf[I1],
+                      acc_thread_buf[I2],
+                      acc_thread_buf[I3]);
+           }
+#endif
+
+           block_sync_lds(); // wait for gemm1 LDS read
+
+           SubThreadBlock<BlockSize> p_thread_copy_subgroup(blockwise_gemm.GetWaveIdx()[I0],
+                                                            blockwise_gemm.GetWaveIdx()[I1]);
+
+           static_assert(sfc_p_m0_n0_m1_n1_m2_n2.GetNumOfAccess() == num_vgrad_gemm_loop, "");
+
+           vgrad_acc_thread_buf.Clear();
+
+           // TODO ANT: single buffer prefetch pipeline
+           static_for<0, num_vgrad_gemm_loop, 1>{}([&](auto vgrad_gemm_loop_idx) { // gemm dV
+               // load VGrad Gemm B
+               ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1, ygrad_grid_buf);
+
+               // load VGrad Gemm A
+               const auto p_nd_idx =
+                   sfc_p_m0_n0_m1_n1_m2_n2.GetIndexTupleOfNumber(vgrad_gemm_loop_idx);
+               constexpr auto mwave_range = make_tuple(
+                   p_nd_idx[I2], p_nd_idx[I2] + p_block_slice_lengths_m0_n0_m1_n1[I2]);
+               constexpr auto nwave_range = make_tuple(
+                   p_nd_idx[I3], p_nd_idx[I3] + p_block_slice_lengths_m0_n0_m1_n1[I3]);
+
+               if(p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range))
+               {
+                   p_thread_copy_vgpr_to_lds.Run(
+                       p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                       make_tuple(p_nd_idx[I0], p_nd_idx[I1], I0, I0, I0, I0, I0, I0),
+                       acc_thread_buf,
+                       p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                       p_block_buf);
+               }
+
+               // ygrad slice window is moved with MoveSrcSliceWindow() since it is dynamic buffer
+               // p slice window is moved by loop index
+               ygrad_blockwise_copy.MoveSrcSliceWindow(ygrad_grid_desc_m0_o_m1,
+                                                       ygrad_block_slice_copy_step);
+
+               block_sync_lds(); // sync before write
+
+               ygrad_blockwise_copy.RunWrite(ygrad_block_desc_m0_o_m1, ygrad_block_buf);
+
+#if 0
+               if (hipBlockIdx_x == 0)
+               {
+                   debug::print_shared(
+                       p_block_buf.p_data_,
+                       index_t(p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize()));
+               }
+#endif
+#if 0
+               if (hipBlockIdx_x == 0)
+               {
+                   debug::print_shared(ygrad_block_buf.p_data_,
+                                       index_t(ygrad_block_desc_m0_o_m1.GetElementSpaceSize()));
+               }
+#endif
+
+               block_sync_lds(); // sync before read
+
+               vgrad_blockwise_gemm.Run(p_block_buf, ygrad_block_buf, vgrad_acc_thread_buf);
+
+#if 1
+               if(hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
+               {
+                   printf("outer %d inner %d tid %zd, dV[0:3] = %f, %f, %f, %f\n",
+                          gemm1_k_block_outer_index,
+                          vgrad_gemm_loop_idx.value,
+                          hipThreadIdx_x,
+                          vgrad_acc_thread_buf[I0],
+                          vgrad_acc_thread_buf[I1],
+                          vgrad_acc_thread_buf[I2],
+                          vgrad_acc_thread_buf[I3]);
+               }
+#endif
+           });
+           // end gemm dV
+
+           // atomic_add vgrad
+           vgrad_thread_copy_vgpr_to_global.Run(vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
+                                                make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+                                                vgrad_acc_thread_buf,
+                                                vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
+                                                vgrad_grid_buf);

            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1,
                                                a_block_reset_copy_step); // rewind K
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_bk0_n_bk1,
                                                b_block_reset_copy_step); // rewind K and step N
+           ygrad_blockwise_copy.MoveSrcSliceWindow(ygrad_grid_desc_m0_o_m1,
+                                                   ygrad_block_reset_copy_step); // rewind M
+           vgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+               vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, vgrad_block_slice_copy_step); // step N

-           // update before next j iteration
-           running_max = running_max_new;
-           running_sum = running_sum_new;
-
-           block_sync_lds(); // wait for gemm1 LDS read
        } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
+#endif

        // TODO ANT:
        // shuffle dQ and write
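Structurally, the new loop computes dV = Pᵀ · dY one MPerBlock slab at a time (stepping the dY window by YGrad_M0 per inner iteration and rewinding it at the end of each outer j iteration) and accumulates the partial results into global memory with atomic adds, so workgroups covering different row slabs of the same (n, o) tile can write concurrently. A plain-C++ sketch of that accumulation pattern follows; it is single-threaded, so a plain += stands in for the AtomicAdd data operation, and the sizes are toy values.

    #include <cstdio>
    #include <vector>

    int main()
    {
        const int M = 8, N = 4, O = 3, MPerBlock = 4; // toy sizes
        std::vector<float> P(M * N, 0.5f);            // softmax output, row-major [M, N]
        std::vector<float> dY(M * O, 1.0f);           // upstream gradient, row-major [M, O]
        std::vector<float> dV(N * O, 0.0f);           // accumulated "in global memory"

        for(int m0 = 0; m0 < M; m0 += MPerBlock)      // one slab per outer iteration
            for(int n = 0; n < N; ++n)
                for(int o = 0; o < O; ++o)
                {
                    float acc = 0.f;
                    for(int m = m0; m < m0 + MPerBlock; ++m)
                        acc += P[m * N + n] * dY[m * O + o]; // dV = P^T * dY over this slab
                    dV[n * O + o] += acc; // AtomicAdd in the kernel; += here, single-threaded
                }

        std::printf("dV[0][0] = %f (expect %f)\n", dV[0], 0.5f * M);
    }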
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp  View file @ b67a58c0
@@ -137,6 +137,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
        static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
            constexpr index_t src_offset = src_desc.CalculateOffset(
                src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
+           // Sequence<num_access, idx_1d.value, i.value, src_offset>{}.foo();

            SrcData v;
include/ck/utility/thread_group.hpp  View file @ b67a58c0
@@ -28,8 +28,8 @@ struct SubThreadBlock
    __device__ static constexpr index_t GetNumOfThread() { return kNumThread_; }

-   template <typename Tuple2>
-   __device__ constexpr bool IsBelong(const Tuple2& mwave_range, const Tuple2& nwave_range)
+   template <typename TupleArg1, typename TupleArg2>
+   __device__ constexpr bool IsBelong(const TupleArg1& mwave_range, const TupleArg2& nwave_range)
    {
        // wave_range[I0] inclusive, wave_range[I1] exclusive
        if(mwave_ < mwave_range[I0])
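With two independent template parameters, IsBelong accepts any pair of tuple-like ranges; the dV path uses it so that only the waves whose (mwave, nwave) falls inside the current P sub-tile perform the VGPR-to-LDS copy. A minimal stand-in showing the same half-open range test (this is not the CK class, just the predicate it implements):

    #include <cstdio>
    #include <utility>

    // Same convention as the kernel comment: range.first inclusive, range.second exclusive.
    bool is_belong(int mwave, int nwave, std::pair<int, int> m_range, std::pair<int, int> n_range)
    {
        return mwave >= m_range.first && mwave < m_range.second &&
               nwave >= n_range.first && nwave < n_range.second;
    }

    int main()
    {
        std::printf("%d\n", is_belong(1, 0, {0, 2}, {0, 1})); // 1: wave participates in the copy
        std::printf("%d\n", is_belong(2, 0, {0, 2}, {0, 1})); // 0: outside the current sub-tile
    }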
library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp  View file @ b67a58c0
@@ -149,6 +149,13 @@ struct ReferenceSoftmax : public device::BaseOperator
                                 ck::type_convert<AccDataType>(
                                     arg.sm_stats_ptr_[0](to_sm_stats_idx(idx)))) +
                             arg.beta_ * self(idx);
+               // printf(
+               //     "exponent %f, exp() = %f\n",
+               //     ck::type_convert<AccDataType>(arg.in_(idx)) -
+               //         ck::type_convert<AccDataType>(arg.sm_stats_ptr_[0](to_sm_stats_idx(idx))),
+               //     std::exp(
+               //         ck::type_convert<AccDataType>(arg.in_(idx)) -
+               //         ck::type_convert<AccDataType>(arg.sm_stats_ptr_[0](to_sm_stats_idx(idx)))));
            });

            return 0;
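The commented-out print shows the quantity the reference path exponentiates: the input minus the stored per-row softmax statistic. With a log-sum-exp statistic this yields the softmax probabilities directly; using s for the scores, m for the row maximum and lse for the stored statistic, the relation is

    p_{ij} = \exp\left(s_{ij} - \mathrm{lse}_i\right), \qquad
    \mathrm{lse}_i = m_i + \log \sum_j \exp\left(s_{ij} - m_i\right), \qquad
    m_i = \max_j s_{ij}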
library/include/ck/library/utility/check_err.hpp  View file @ b67a58c0
@@ -148,7 +148,7 @@ check_err(const Range& out,
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
-           if(err_count < 5)
+           if(err_count < 32)
            {
                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
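Raising the cap from 5 to 32 simply prints more of the failing elements, which helps when dV is validated under a relaxed tolerance and the first few mismatches are not representative. The sketch below shows the same report-capping pattern in plain C++; it is not the CK implementation, and the rtol/atol defaults are placeholders.

    #include <cmath>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Tolerance check that reports at most `max_reports` mismatches,
    // mirroring the err_count < 32 cap in check_err.
    bool check_close(const std::vector<float>& out, const std::vector<float>& ref,
                     double rtol = 1e-3, double atol = 1e-3, int max_reports = 32)
    {
        int err_count  = 0;
        double max_err = 0.0;
        for(std::size_t i = 0; i < out.size(); ++i)
        {
            const double err = std::fabs(out[i] - ref[i]);
            if(err > atol + rtol * std::fabs(ref[i]))
            {
                max_err = err > max_err ? err : max_err;
                if(err_count++ < max_reports)
                    std::printf("out[%zu] != ref[%zu]: %f != %f\n", i, i, out[i], ref[i]);
            }
        }
        if(err_count > 0)
            std::printf("%d mismatches, max abs err %f\n", err_count, max_err);
        return err_count == 0;
    }

    int main()
    {
        std::vector<float> ref = {1.f, 2.f, 3.f}, out = {1.f, 2.5f, 3.f};
        return check_close(out, ref) ? 0 : 1;
    }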
test/space_filling_curve/CMakeLists.txt  View file @ b67a58c0
    add_test_executable(test_space_filling_curve space_filling_curve.cpp)
-   add_test_executable(test_threadwise_copy test_threadwise_copy.cpp)