gaoqiong / composable_kernel

Commit 66052232, authored Feb 13, 2023 by danyao12
Parents: 5eb5e316, bf80ceee

    sync attn-bwd-dropout

Showing 13 changed files with 1814 additions and 3219 deletions.
Changed files:

  example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt (+4, -0)
  example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp (+153, -26)
  example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16_dropout.cpp (+0, -808)
  example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_fp16.cpp (+5, -0)
  example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc (+20, -0)
  include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp (+5, -0)
  include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp (+105, -11)
  include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle (+1058, -0)
  include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp (+257, -51)
  include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp (+0, -2316)
  include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp (+168, -1)
  include/ck/utility/data_type.hpp (+36, -0)
  include/ck/utility/philox_rand.hpp (+3, -6)
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt

(The conflict markers below are part of the rendered merge diff; the two HEAD-side examples appear to have been dropped in the resolution.)

@@ -10,8 +10,12 @@ add_example_executable(example_batched_multihead_attention_forward_fp16 batched_
add_example_executable(example_grouped_multihead_attention_forward_bf16 grouped_multihead_attention_forward_bf16.cpp)
add_example_executable(example_batched_multihead_attention_forward_bf16 batched_multihead_attention_forward_bf16.cpp)
add_example_executable(example_batched_multihead_attention_backward_fp16 batched_multihead_attention_backward_fp16.cpp)
<<<<<<< HEAD
add_example_executable(example_batched_multihead_attention_backward_pt1_fp16 batched_multihead_attention_backward_pt1_fp16.cpp)
add_example_executable(example_batched_multihead_attention_backward_fp16_dropout batched_multihead_attention_backward_fp16_dropout.cpp)
=======
>>>>>>> attn-bwd-dropout
add_custom_target(example_gemm_scale_softmax_gemm)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp

@@ -43,23 +43,27 @@ Kernel outputs:
 #include "ck/library/utility/host_tensor_generator.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"

 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;

 using F16 = ck::half_t;
 using F32 = float;
+using U16 = unsigned short;

 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using Scale       = ck::tensor_operation::element_wise::Scale;

 using QKVElementOp = PassThrough;
 using YElementOp   = PassThrough;
 using VElementOp   = Scale;

 using DataType         = F16;
 using AccDataType      = F32;
 using ShuffleDataType  = F32;
 using LSEDataType      = F32;
+using ZDataType        = U16;
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;
...
@@ -91,6 +95,7 @@ using DeviceGemmInstance =
         NumDimK,
         NumDimO,
         DataType,
+        ZDataType,
         LSEDataType,
         Acc0BiasDataType,
         Acc1BiasDataType,
...
@@ -182,12 +187,16 @@ using ReferenceGemmGradInstance = ck::tensor_operation::host::ReferenceBatchedGe
                                                                   PassThrough,
                                                                   PassThrough,
                                                                   Scale>;

+// Ref dropout
+using ReferenceDropoutInstance =
+    ck::tensor_operation::host::ReferenceDropout<ushort, DataType, DataType>;

 template <typename TensorQ,
           typename TensorK,
           typename TensorV,
           typename TensorS,
           typename TensorP,
+          typename TensorZ,
           typename TensorY,
           typename TensorLSE = TensorP>
 void run_attention_fwd_host(const TensorQ& q_g_m_k,
...
@@ -197,7 +206,11 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorS& s_g_m_n,
                             TensorP& p_g_m_n,
                             TensorY& y_g_m_o,
-                            TensorLSE& lse_g_m)
+                            TensorLSE& lse_g_m,
+                            TensorP& p_drop_g_m_n,
+                            TensorZ& z_g_m_n,
+                            ushort p_dropout_in_16bits,
+                            float rp_dropout)
 {
     // S = alpha * Q * K^T
     auto k_g_k_n = k_g_n_k.Transpose({0, 2, 1});
...
@@ -225,11 +238,18 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     ref_softmax_invoker.Run(ref_softmax_argument);

-    // Y = P * V
+    // P_dropped
+    auto ref_dropout         = ReferenceDropoutInstance{};
+    auto ref_dropout_invoker = ref_dropout.MakeInvoker();
+    auto ref_dropout_argment = ref_dropout.MakeArgument(
+        z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
+    ref_dropout_invoker.Run(ref_dropout_argment);
+
+    // Y = P_dropout * V
     auto ref_gemm1          = ReferenceGemm1Instance{};
     auto ref_gemm1_invoker  = ref_gemm1.MakeInvoker();
     auto ref_gemm1_argument = ref_gemm1.MakeArgument(
-        p_g_m_n, v_g_n_o, y_g_m_o, PassThrough{}, PassThrough{}, PassThrough{});
+        p_drop_g_m_n, v_g_n_o, y_g_m_o, PassThrough{}, PassThrough{}, PassThrough{});
     ref_gemm1_invoker.Run(ref_gemm1_argument);
 }
...
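Putting the reference steps above into formulas (a restatement of the host code, assuming ReferenceDropout uses the usual inverted-dropout keep-if-below-threshold convention; Z is the saved 16-bit random matrix and t the quantized keep threshold):

$$S = \alpha\, Q K^{T}, \qquad P = \operatorname{softmax}(S),$$
$$P_{\text{drop}} = \frac{1}{1 - p_{\text{drop}}}\,\bigl(\mathbb{1}[Z \le t] \odot P\bigr), \qquad Y = P_{\text{drop}}\, V.$$

The 1/(1 - p_drop) factor is the rp_dropout argument: rescaling the kept entries keeps the expected value of Y equal to the no-dropout forward pass.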
@@ -256,6 +276,13 @@ int run(int argc, char* argv[])
     bool input_permute  = false;
     bool output_permute = false;

+    float p_drop                    = 0.2;
+    float p_dropout                 = 1 - p_drop;
+    uint16_t p_dropout_in_16bits    = uint16_t(std::floor(p_dropout * 65535.0));
+    float rp_dropout                = 1.0 / p_dropout;
+    const unsigned long long seed   = 1;
+    const unsigned long long offset = 0;

     if(argc == 1)
     {
         // use default case
...
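As a minimal scalar illustration of how these three quantities interact (an assumption based on the threshold comparison that p_dropout_in_16bits implies, not code from this commit): a 16-bit Philox sample z keeps an element when z falls at or below the quantized keep probability, and kept elements are rescaled by rp_dropout.

#include <cmath>
#include <cstdint>

// Hypothetical helper, not part of this commit: inverted dropout for one
// element, given its 16-bit random draw z and the drop probability p_drop.
inline float apply_dropout_scalar(float x, uint16_t z, float p_drop)
{
    const float p_keep       = 1.0f - p_drop;
    const uint16_t threshold = static_cast<uint16_t>(std::floor(p_keep * 65535.0f));
    const float rp_dropout   = 1.0f / p_keep; // rescale so E[output] == input
    return (z <= threshold) ? x * rp_dropout : 0.0f;
}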
@@ -321,6 +348,11 @@ int run(int argc, char* argv[])
             ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
             : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
+    std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
+    std::vector<ck::index_t> z_gs_ms_ns_strides =
+        input_permute
+            ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
+            : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]

     // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
     // Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
     //    = exp(Si) / exp(log(sum(exp() + ...)))
...
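Written out, the identity the comment compresses (per row i of the score matrix S):

$$\mathrm{LSE}_i = \log \sum_j e^{S_{ij}}, \qquad P_{ij} = \frac{e^{S_{ij}}}{\sum_j e^{S_{ij}}} = e^{S_{ij} - \mathrm{LSE}_i},$$

so the backward pass can rebuild any softmax row from S and one stored scalar per row, without redoing the row-wise max/sum reductions.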
@@ -332,6 +364,7 @@ int run(int argc, char* argv[])
     Tensor<DataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
     Tensor<DataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
+    Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
     Tensor<DataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
     Tensor<DataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
     Tensor<DataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
...
@@ -339,10 +372,12 @@ int run(int argc, char* argv[])
     std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
     std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
+    std::cout << "z_gs_ms_ns: " << z_gs_ms_ns.mDesc << std::endl;
     std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
     std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
     std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;

+    z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{0});
     switch(init_method)
     {
         case 0: break;
...
@@ -408,9 +443,11 @@ int run(int argc, char* argv[])
     // calculate y & log-sum-exp beforehand
     Tensor<DataType> q_g_m_k({BatchCount, M, K});
     Tensor<DataType> k_g_n_k({BatchCount, N, K});
+    Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
     Tensor<DataType> v_g_n_o({BatchCount, N, O});
     Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
     Tensor<DataType> p_g_m_n({BatchCount, M, N});
+    Tensor<DataType> p_drop_g_m_n({BatchCount, M, N});
     Tensor<DataType> y_g_m_o({BatchCount, M, O});
     Tensor<LSEDataType> lse_g_m({BatchCount, M});
...
@@ -418,12 +455,25 @@ int run(int argc, char* argv[])
         [&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
     k_gs_ns_ks.ForEach(
         [&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
+    z_gs_ms_ns.ForEach(
+        [&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
     v_gs_os_ns.ForEach(
         [&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
     lse_gs_ms.ForEach(
         [&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });

-    run_attention_fwd_host(q_g_m_k, k_g_n_k, v_g_n_o, alpha, s_g_m_n, p_g_m_n, y_g_m_o, lse_g_m);
+    run_attention_fwd_host(q_g_m_k,
+                           k_g_n_k,
+                           v_g_n_o,
+                           alpha,
+                           s_g_m_n,
+                           p_g_m_n,
+                           y_g_m_o,
+                           lse_g_m,
+                           p_drop_g_m_n,
+                           z_g_m_n,
+                           p_dropout_in_16bits,
+                           rp_dropout);

     y_gs_ms_os.ForEach(
         [&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); });
...
@@ -433,6 +483,7 @@ int run(int argc, char* argv[])
     // qkv gradients have the same descriptor as with qkv
     DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
     DeviceMem k_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
+    DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
     DeviceMem v_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
     DeviceMem y_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
     DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize());
...
@@ -443,6 +494,7 @@ int run(int argc, char* argv[])
     q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
     k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
+    z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
     v_device_buf.ToDevice(v_gs_os_ns.mData.data());
     y_device_buf.ToDevice(y_gs_ms_os.mData.data());
     lse_device_buf.ToDevice(lse_gs_ms.mData.data());
...
@@ -450,11 +502,59 @@ int run(int argc, char* argv[])
     kgrad_device_buf.SetZero();
     vgrad_device_buf.SetZero();

     auto gemm    = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();

+    // get z matrix
+    {
+        auto argument = gemm.MakeArgument(
+            static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
+            static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
+            static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
+            static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
+            static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
+            static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
+            static_cast<DataType*>(ygrad_device_buf.GetDeviceBuffer()),
+            static_cast<DataType*>(qgrad_device_buf.GetDeviceBuffer()),
+            static_cast<DataType*>(kgrad_device_buf.GetDeviceBuffer()),
+            static_cast<DataType*>(vgrad_device_buf.GetDeviceBuffer()),
+            {}, // std::array<void*, 1> p_acc0_biases;
+            {}, // std::array<void*, 1> p_acc1_biases;
+            q_gs_ms_ks_lengths,
+            q_gs_ms_ks_strides,
+            k_gs_ns_ks_lengths,
+            k_gs_ns_ks_strides,
+            z_gs_ms_ns_lengths,
+            z_gs_ms_ns_strides,
+            v_gs_os_ns_lengths,
+            v_gs_os_ns_strides,
+            y_gs_ms_os_lengths,
+            y_gs_ms_os_strides,
+            lse_gs_ms_lengths,
+            {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
+            {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
+            {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
+            {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
+            QKVElementOp{},
+            QKVElementOp{},
+            Scale{alpha},
+            QKVElementOp{},
+            YElementOp{},
+            p_drop,
+            std::tuple<unsigned long long, unsigned long long>(seed, offset));
+
+        if(!gemm.IsSupportedArgument(argument))
+        {
+            std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
+            return 0;
+        }
+
+        invoker.Run(argument, StreamConfig{nullptr, false});
+    }
+
+    // no need to output the z matrix this time
     auto argument = gemm.MakeArgument(
         static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
         static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
+        static_cast<ZDataType*>(nullptr), // set to nullptr
         static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
         static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
         static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
...
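Note the two-pass pattern here: the first invocation passes the real z buffer so the kernel writes out the Philox-generated random matrix Z, which the host needs in order to replay the forward pass with the exact same dropout mask during verification; the second argument set passes nullptr for Z, so the timed run below does not pay for that extra global-memory traffic.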
@@ -468,6 +568,8 @@ int run(int argc, char* argv[])
         q_gs_ms_ks_strides,
         k_gs_ns_ks_lengths,
         k_gs_ns_ks_strides,
+        z_gs_ms_ns_lengths,
+        z_gs_ms_ns_strides,
         v_gs_os_ns_lengths,
         v_gs_os_ns_strides,
         y_gs_ms_os_lengths,
...
@@ -481,15 +583,11 @@ int run(int argc, char* argv[])
         QKVElementOp{},
         Scale{alpha},
         QKVElementOp{},
-        YElementOp{});
-
-    if(!gemm.IsSupportedArgument(argument))
-    {
-        std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
-        return 0;
-    }
+        YElementOp{},
+        p_drop,
+        std::tuple<unsigned long long, unsigned long long>(seed, offset));

     kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
     vgrad_device_buf.SetZero();

     float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

     // 5 GEMM ops in total:
...
@@ -511,9 +609,32 @@ int run(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
               << " GB/s, " << gemm.GetTypeString() << std::endl;

+    // copy z matrix data from device
+    z_device_buf.FromDevice(z_g_m_n.mData.data());
+    // std::cout << "z_g_m_n ref:\n" << z_g_m_n;

     bool pass = true;

     if(do_verification)
     {
+        // run forward again for y, because z_g_m_n has been updated
+        run_attention_fwd_host(q_g_m_k,
+                               k_g_n_k,
+                               v_g_n_o,
+                               alpha,
+                               s_g_m_n,
+                               p_g_m_n,
+                               y_g_m_o,
+                               lse_g_m,
+                               p_drop_g_m_n,
+                               z_g_m_n,
+                               p_dropout_in_16bits,
+                               rp_dropout);
+        y_gs_ms_os.ForEach([&](auto& self, auto idx) {
+            self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
+        });
+        y_device_buf.ToDevice(y_gs_ms_os.mData.data());
+
+        // call kernel again
+        kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
+        vgrad_device_buf.SetZero();
+        invoker.Run(argument, StreamConfig{nullptr, false});
...
@@ -523,6 +644,7 @@ int run(int argc, char* argv[])
         Tensor<DataType> vgrad_g_n_o({BatchCount, N, O});
         Tensor<DataType> sgrad_g_m_n({BatchCount, M, N});
         Tensor<DataType> pgrad_g_m_n({BatchCount, M, N});
+        Tensor<DataType> pgrad_drop_g_m_n({BatchCount, M, N});
         Tensor<DataType> ygrad_g_m_o({BatchCount, M, O});
         Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M});
...
@@ -544,20 +666,26 @@ int run(int argc, char* argv[])
         auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker();
         using RefGemmGradArg       = ReferenceGemmGradInstance::Argument;

-        // dP = dY * V^T
+        // dP_dropout = dY * V^T
         auto v_g_o_n = v_g_n_o.Transpose({0, 2, 1});
         ref_gemm_grad_invoker.Run(RefGemmGradArg{
-            ygrad_g_m_o, v_g_o_n, pgrad_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
+            ygrad_g_m_o, v_g_o_n, pgrad_drop_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
 #if PRINT_HOST
         {
             std::cout << "===== dP = dY * V^T\n";
-            std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
+            std::cout << "ygrad_drop_g_m_o ref:\n" << ygrad_drop_g_m_n;
             std::cout << "v_g_o_n ref:\n" << v_g_o_n;
-            std::cout << "pgrad_g_m_n ref:\n" << pgrad_g_m_n;
+            std::cout << "pgrad_drop_g_m_n ref:\n" << pgrad_drop_g_m_n;
         }
 #endif

-        // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
+        // dP = dP_dropout x Z
+        auto ref_dropout         = ReferenceDropoutInstance{};
+        auto ref_dropout_invoker = ref_dropout.MakeInvoker();
+        auto ref_dropout_argment = ref_dropout.MakeArgument(
+            z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+        ref_dropout_invoker.Run(ref_dropout_argment);
+
+        // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
         sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
             float ygrad_dot_y = 0;
             for(int o = 0; o < O; o++)
...
@@ -578,15 +706,14 @@ int run(int argc, char* argv[])
             std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
         }
 #endif

-        // dV = P^T * dY
-        auto p_g_n_m = p_g_m_n.Transpose({0, 2, 1});
+        // dV = P_drop^T * dY
+        auto p_drop_g_n_m = p_drop_g_m_n.Transpose({0, 2, 1});
         ref_gemm_grad_invoker.Run(RefGemmGradArg{
-            p_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.f}});
+            p_drop_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.0f}});
 #if PRINT_HOST
         {
             std::cout << "===== dV = P^T * dY\n";
-            std::cout << "p_g_n_m ref:\n" << p_g_n_m;
+            std::cout << "p_drop_g_n_m ref:\n" << p_drop_g_n_m;
             std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
             std::cout << "vgrad_g_n_o ref:\n" << vgrad_g_n_o;
         }
...
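Collecting the verification steps of this diff in one place (these restate the reference code above; dQ and dK come from the remaining GEMMs of the "5 GEMM ops in total" accounting and are not shown in these hunks):

$$dP_{\text{drop}} = dY\, V^{T}, \qquad dP = \frac{1}{1 - p_{\text{drop}}}\,\bigl(\mathbb{1}[Z \le t] \odot dP_{\text{drop}}\bigr),$$
$$dS_{ij} = P_{ij}\,\bigl(dP_{ij} - dY_i \cdot Y_i\bigr), \qquad dV = P_{\text{drop}}^{T}\, dY.$$

That is, the dropout mask is replayed from the stored Z on the gradient path, and dV sees the dropped (and rescaled) probabilities, matching the forward Y = P_drop * V.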
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16_dropout.cpp
(deleted, 100644 → 0; diff collapsed)
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_fp16.cpp
(mode changed 100755 → 100644)

@@ -33,6 +33,7 @@ using S = ck::Sequence<Is...>;
 using F16 = ck::half_t;
 using F32 = float;
+using U16 = unsigned short;

 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
...
@@ -42,6 +43,7 @@ using B1DataType = F16;
 using AccDataType      = F32;
 using CShuffleDataType = F32;
 using CDataType        = F16;
+using ZDataType        = U16;
 using LSEDataType      = F32;
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;
...
@@ -69,6 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
         NumDimG,
         NumDimM,
         NumDimN,
...
@@ -78,6 +81,7 @@ using DeviceGemmInstance =
         B0DataType,
         B1DataType,
         CDataType,
+        ZDataType,
         LSEDataType,
         Acc0BiasDataType,
         Acc1BiasDataType,
...
@@ -159,4 +163,5 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
 #include "run_grouped_multihead_attention_forward.inc"

 int main(int argc, char* argv[]) { return run(argc, argv); }
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc

@@ -48,6 +48,7 @@ int run(int argc, char* argv[])
     std::vector<const void*> p_b0;
     std::vector<const void*> p_b1;
     std::vector<void*> p_c;
+    std::vector<void*> p_z;
     std::vector<void*> p_lse;

     std::vector<std::vector<int>> g0_g1_m_n_k_o;
...
@@ -55,6 +56,7 @@ int run(int argc, char* argv[])
     std::vector<Tensor<B0DataType>> b0_tensors;
     std::vector<Tensor<B1DataType>> b1_tensors;
     std::vector<Tensor<CDataType>> c_tensors;
+    std::vector<Tensor<ZDataType>> z_tensors;
     std::vector<Tensor<LSEDataType>> lse_tensors;

     using DeviceMemPtr = std::unique_ptr<DeviceMem>;
...
@@ -62,6 +64,7 @@ int run(int argc, char* argv[])
     std::vector<DeviceMemPtr> b0_tensors_device;
     std::vector<DeviceMemPtr> b1_tensors_device;
     std::vector<DeviceMemPtr> c_tensors_device;
+    std::vector<DeviceMemPtr> z_tensors_device;
     std::vector<DeviceMemPtr> lse_tensors_device;

     std::size_t flop = 0, num_byte = 0;
...
@@ -101,6 +104,12 @@ int run(int argc, char* argv[])
             output_permute
                 ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
                 : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
+        std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
+        std::vector<ck::index_t> z_gs_ms_ns_strides =
+            output_permute
+                ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
+                : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]

         std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
         std::vector<ck::index_t> lse_gs_ms_strides =
...
@@ -114,6 +123,8 @@ int run(int argc, char* argv[])
             b1_gs_os_ns_strides,
             c_gs_ms_os_lengths,
             c_gs_ms_os_strides,
+            z_gs_ms_ns_lengths,
+            z_gs_ms_ns_strides,
             lse_gs_ms_lengths,
             lse_gs_ms_strides,
             {}, // acc0_biases_gs_ms_ns_lengths
...
@@ -126,6 +137,7 @@ int run(int argc, char* argv[])
         Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
         Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
         Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
+        Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
         Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);

         int Batch = G0 * G1;
...
@@ -140,10 +152,13 @@ int run(int argc, char* argv[])
                       << "b0_gs_ns_ks[" << i << "]: " << b0_gs_ns_ks.mDesc << ", "
                       << "b1_gs_os_ns[" << i << "]: " << b1_gs_os_ns.mDesc << ", "
                       << "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << ", "
                       << "lse_gs_ms_os[" << i << "]: " << lse_gs_ms_device_result.mDesc
                       << std::endl;
         }

+        z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{0});
         switch(init_method)
         {
             case 0: break;
...
@@ -172,6 +187,7 @@ int run(int argc, char* argv[])
         b0_tensors.push_back(b0_gs_ns_ks);
         b1_tensors.push_back(b1_gs_os_ns);
         c_tensors.push_back(c_gs_ms_os_device_result);
+        z_tensors.push_back(z_gs_ms_ns);
         lse_tensors.push_back(lse_gs_ms_device_result);

         a_tensors_device.emplace_back(std::make_unique<DeviceMem>(
...
@@ -182,6 +198,8 @@ int run(int argc, char* argv[])
             sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()));
         c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
             sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()));
+        z_tensors_device.emplace_back(std::make_unique<DeviceMem>(
+            sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize()));
         lse_tensors_device.emplace_back(std::make_unique<DeviceMem>(
             sizeof(LSEDataType) * lse_gs_ms_device_result.mDesc.GetElementSpaceSize()));
...
@@ -193,6 +211,7 @@ int run(int argc, char* argv[])
         p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
         p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer());
         p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
+        p_z.push_back(z_tensors_device[i]->GetDeviceBuffer());
         p_lse.push_back(lse_tensors_device[i]->GetDeviceBuffer());
     }
...
@@ -209,6 +228,7 @@ int run(int argc, char* argv[])
         p_b0,
         p_b1,
         p_c,
+        p_z,
         p_lse,
         {}, // p_acc0_biases
         {}, // p_acc1_biases
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp

@@ -79,6 +79,7 @@ template <index_t NumDimG,
           typename B0DataType,
           typename B1DataType,
           typename CDataType,
+          typename ZDataType,
           typename LSEDataType,
           typename Acc0BiasDataType,
           typename Acc1BiasDataType,
...
@@ -104,6 +105,9 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
         std::vector<index_t> c_gs_ms_os_lengths;
         std::vector<index_t> c_gs_ms_os_strides;
+        std::vector<index_t> z_gs_ms_ns_lengths;
+        std::vector<index_t> z_gs_ms_ns_strides;
         std::vector<index_t> lse_gs_ms_lengths;
         std::vector<index_t> lse_gs_ms_strides;
...
@@ -119,6 +123,7 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
         std::vector<const void*> p_b0_vec,
         std::vector<const void*> p_b1_vec,
         std::vector<void*> p_c_vec,
+        std::vector<void*> p_z_vec,
         std::vector<void*> p_lse_vec,
         std::vector<std::vector<const void*>> p_acc0_biases_vec,
         std::vector<std::vector<const void*>> p_acc1_biases_vec,
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp
(diff collapsed)

include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_train_xdl_cshuffle.hpp
→ include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle
(renamed; diff collapsed)

include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
(diff collapsed)

include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
(deleted, 100644 → 0; diff collapsed)
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp

@@ -35,6 +35,7 @@ template <typename FloatAB,
           typename BGridDesc_BK0_N_BK1,
           typename B1GridDesc_BK0_N_BK1,
           typename CGridDesc_M_N,
+          typename ZGridDesc_M_N,
           typename LSEGridDesc_M,
           index_t NumGemmKPrefetchStage,
           index_t BlockSize,
...
@@ -97,6 +98,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
     static constexpr auto I6 = Number<6>{};
     static constexpr auto I7 = Number<7>{};

+    static constexpr auto WaveSize = 64;
+
     // K1 should be Number<...>
     // Gemm0
     static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
...
@@ -116,6 +119,65 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
     using GridwiseGemmPipe = remove_cvref_t<
         decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

+    // C desc for source in blockwise copy
+    __host__ __device__ static constexpr auto
+    MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const ZGridDesc_M_N& z_grid_desc_m_n) // => for z use
+    {
+        const auto M = z_grid_desc_m_n.GetLength(I0);
+        const auto N = z_grid_desc_m_n.GetLength(I1);
+
+        constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma;
+        constexpr auto N3   = mfma.num_groups_per_blk;
+        constexpr auto N4   = mfma.num_input_blks;
+        constexpr auto N5   = mfma.group_size;
+        return transform_tensor_descriptor(
+            z_grid_desc_m_n,
+            make_tuple(make_unmerge_transform(
+                           make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
+                       make_unmerge_transform(
+                           make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
+    }
+
+    __host__ __device__ static constexpr auto
+    MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const index_t M, const index_t N) // => for z use
+    {
+        constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma;
+        constexpr auto N3   = mfma.num_groups_per_blk;
+        constexpr auto N4   = mfma.num_input_blks;
+        constexpr auto N5   = mfma.group_size;
+        return transform_tensor_descriptor(
+            make_naive_tensor_descriptor_packed(make_tuple(M, N)),
+            make_tuple(make_unmerge_transform(
+                           make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
+                       make_unmerge_transform(
+                           make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
+    }
+
+    __device__ static auto GetGemm0WaveIdx()
+    {
+        const index_t thread_id = get_thread_local_1d_id();
+
+        constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(make_merge_transform(make_tuple(Gemm0MWaves, Gemm0NWaves, WaveSize))),
+            make_tuple(Sequence<0, 1, 2>{}),
+            make_tuple(Sequence<0>{}));
+
+        return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
+    }
+
+    __device__ static auto GetGemm0WaveMNIdx(const index_t thread_id)
+    {
+        constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor(
+            make_tuple(make_merge_transform(make_tuple(WaveSize / MPerXdl, MPerXdl))),
+            make_tuple(Sequence<0, 1>{}),
+            make_tuple(Sequence<0>{}));
+
+        return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
+    }
+
     template <typename ABlockDesc_AK0_M_AK1>
     __host__ __device__ static constexpr auto
     MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
...
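The unmerge transforms above split the flat (M, N) coordinates of Z into per-block, per-repeat, per-wave, and per-lane pieces. As a plain-C++ sketch of what one unmerge encodes (illustrative only; the hypothetical factors D1..D3 stand in for MXdlPerWave, Gemm0MWaves, and MPerXdl):

#include <array>
#include <cstdio>

// Decompose a flat row index m into (m0, m1, m2, m3) such that
// m == ((m0 * D1 + m1) * D2 + m2) * D3 + m3 -- the arithmetic an
// unmerge transform performs when mapping multi-indices back to memory.
std::array<int, 4> unmerge(int m, int D1, int D2, int D3)
{
    const int m3 = m % D3; m /= D3;
    const int m2 = m % D2; m /= D2;
    const int m1 = m % D1; m /= D1;
    return {m, m1, m2, m3}; // m now holds m0 = original_m / (D1 * D2 * D3)
}

int main()
{
    // e.g. a 128-row tile split as 2 xdl-repeats x 2 waves x 32 rows-per-xdl
    const auto idx = unmerge(173, 2, 2, 32);
    std::printf("m0=%d m1=%d m2=%d m3=%d\n", idx[0], idx[1], idx[2], idx[3]);
    return 0;
}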
@@ -323,6 +385,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;

+    using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<
+        decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ZGridDesc_M_N{}))>;

     struct SharedMemTrait
     {
         // LDS allocation for A and B: be careful of alignment
...
@@ -367,6 +432,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
         const FloatAB* __restrict__ p_b_grid,
         const FloatAB* __restrict__ p_b1_grid,
         FloatC* __restrict__ p_c_grid,
+        unsigned short* __restrict__ p_z_grid,
         FloatLSE* __restrict__ p_lse_grid,
         void* __restrict__ p_shared,
         const AElementwiseOperation& a_element_op,
...
@@ -379,6 +445,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
         const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
         const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
             c_grid_desc_mblock_mperblock_nblock_nperblock,
+        const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
+            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
         const LSEGridDesc_M& lse_grid_desc_m,
         const Block2CTileMap& block_2_ctile_map,
         const C0MatrixMask& c0_matrix_mask,
...
@@ -782,6 +850,79 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
         // gemm1 K loop
         index_t gemm1_k_block_outer_index = 0;

+        // => z for dropout
+        //
+        // z vgpr copy to global
+        //
+        // z matrix threadwise desc
+        constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
+            make_naive_tensor_descriptor_packed(make_tuple(I1,   // MBlockId
+                                                           I1,   // NBlockID
+                                                           m0,   // MRepeat
+                                                           n0,   // NRepeat
+                                                           m1,   // MWaveId
+                                                           n1,   // NWaveId
+                                                           m2,   // MPerXdl
+                                                           n2,   // NGroupNum
+                                                           n3,   // NInputNum
+                                                           n4)); // registerNum
+
+        StaticBuffer<AddressSpaceEnum::Vgpr,
+                     unsigned short,
+                     z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
+                     true>
+            z_tenor_buffer;
+        z_tenor_buffer.Clear();
+
+        // z matrix global desc
+        /*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
+        const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
+        auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
+            MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(M, N);*/
+
+        auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
+
+        const auto wave_id     = GetGemm0WaveIdx();
+        const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
+
+        auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
+            ushort,
+            ushort,
+            decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
+            decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
+            tensor_operation::element_wise::PassThrough,
+            Sequence<I1, // MBlockId
+                     I1, // NBlockID
+                     m0, // MRepeat
+                     n0, // NRepeat
+                     m1, // MWaveId
+                     n1, // NWaveId
+                     m2, // MPerXdl
+                     n2, // NGroupNum
+                     n3, // NInputNum
+                     n4>,
+            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
+            9,  // DstVectorDim
+            n4, // DstScalarPerVector
+            InMemoryDataOperationEnum::Set,
+            1, // DstScalarStrideInVector
+            true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                  make_multi_index(block_work_idx[I0], // MBlockId
+                                   0,                  // NBlockId
+                                   0,                  // mrepeat
+                                   0,                  // nrepeat
+                                   wave_id[I0],        // MWaveId
+                                   wave_id[I1],        // NWaveId
+                                   wave_m_n_id[I1],    // MPerXdl
+                                   0,                  // group
+                                   wave_m_n_id[I0],    // NInputIndex
+                                   0),
+                  tensor_operation::element_wise::PassThrough{}};
+        // => z for dropout

         do
         {
             auto n_block_data_idx_on_grid =
...
@@ -876,9 +1017,35 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
             if constexpr(IsDropout) // dropout
             {
-                blockwise_dropout.ApplyDropout(acc_thread_buf, ph);
+                // save z to global
+                if(p_z_grid)
+                {
+                    // P_dropped
+                    blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
+                                                            decltype(z_tenor_buffer),
+                                                            true>(acc_thread_buf, ph, z_tenor_buffer);
+
+                    z_thread_copy_vgpr_to_global.Run(
+                        z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                        make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
+                        z_tenor_buffer,
+                        z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                        z_grid_buf);
+                }
+                else
+                {
+                    // P_dropped
+                    blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf), true>(
+                        acc_thread_buf, ph);
+                }
             }
+            // if constexpr(IsDropout) // dropout
+            // {
+            //     blockwise_dropout.ApplyDropout(acc_thread_buf, ph);
+            // }

             // TODO: may convert to log domain
             running_max_new = mathext::max(max, running_max);
             running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
...
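What the two ApplyDropout branches above do per element, in scalar form (a sketch under the assumption that the Philox stream yields one 16-bit draw per accumulator element and that the extra template flag controls whether the raw draws are also retained for the Z write-out; none of these names are CK API):

#include <cstdint>
#include <vector>

// Scalar model of the blockwise dropout step: 'rand16' holds one 16-bit
// Philox draw per accumulator element; 'z_out' (optional) receives the raw
// draws, mirroring the z_tenor_buffer -> z_grid_buf copy above, so the host
// can replay the same mask later.
void apply_dropout_model(std::vector<float>& acc,
                         const std::vector<uint16_t>& rand16,
                         uint16_t threshold,   // p_dropout_in_16bits
                         float rp_dropout,     // 1 / (1 - p_drop)
                         std::vector<uint16_t>* z_out = nullptr)
{
    for(std::size_t i = 0; i < acc.size(); ++i)
    {
        const uint16_t z = rand16[i];
        acc[i] = (z <= threshold) ? acc[i] * rp_dropout : 0.0f;
        if(z_out)
            (*z_out)[i] = z;
    }
}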
include/ck/utility/data_type.hpp

@@ -1010,6 +1010,42 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
     return uint16_t(u.int32 >> 16);
 }

+// convert fp16 to bf16
+template <>
+inline __host__ __device__ bhalf_t type_convert<bhalf_t, half_t>(half_t x)
+{
+    union
+    {
+        float fp32;
+        uint32_t int32;
+    } u = {static_cast<float>(x)};
+
+    return uint16_t(u.int32 >> 16);
+}
+
+template <>
+inline __host__ __device__ bhalf2_t type_convert<bhalf2_t, half2_t>(half2_t x)
+{
+    float y0{0}, y1{0};
+    bhalf2_t y{0};
+
+    asm volatile("v_cvt_f32_f16 %0, %1" : "=v"(y0) : "v"(x));
+    asm volatile("v_cvt_f32_f16 %0, %1 src0_sel:WORD_1" : "=v"(y1) : "v"(x));
+    asm volatile("v_pack_b32_f16 %0, %1, %2 op_sel:[1, 1]" : "=v"(y) : "v"(y0), "v"(y1));
+    return y;
+}
+
 template <typename T>
 struct NumericLimits
 {
...
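Both new conversions reduce to the same trick as the existing float-to-bf16 specialization: widen to fp32 and keep the high 16 bits (the v_pack_b32_f16 op_sel:[1, 1] variant packs the high halves of the two fp32 temporaries). A host-only sketch in plain C++, with no HIP intrinsics, showing the truncation and why it is truncation rather than round-to-nearest:

#include <cstdint>
#include <cstring>
#include <cstdio>

// Truncating fp32 -> bf16: bf16 is the top 16 bits of an IEEE-754 float,
// so shifting simply discards the low mantissa bits without rounding.
uint16_t float_to_bf16_trunc(float f)
{
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits)); // well-defined type pun
    return static_cast<uint16_t>(bits >> 16);
}

int main()
{
    // 1.0058594f is 0x3F80C000; truncation keeps 0x3F80 (== 1.0 in bf16),
    // whereas round-to-nearest would have produced 0x3F81.
    std::printf("bf16(1.0058594) = 0x%04x\n", float_to_bf16_trunc(1.0058594f));
    return 0;
}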
include/ck/utility/philox_rand.hpp

@@ -109,12 +109,9 @@ class philox
     __device__ uint2 u32_high_low_multi(const unsigned int a, const unsigned int b)
     {
         uint2* res;
-        uint2 tmp_res;
-        asm("v_mul_hi_u32 %0, %2, %3\n\t"
-            "v_mul_lo_u32 %1, %2, %3\n\t"
-            : "=v"(tmp_res.x), "=v"(tmp_res.y)
-            : "v"(a), "v"(b));
-        res = &tmp_res;
+        unsigned long long tmp;
+        tmp = static_cast<unsigned long long>(a) * b;
+        res = reinterpret_cast<uint2*>(&tmp);
         return *res;
     }
...
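The replacement drops the v_mul_hi_u32/v_mul_lo_u32 inline asm in favor of a plain 64-bit multiply and lets the compiler pick the instructions. One caveat worth checking against the callers: on a little-endian target the reinterpret_cast places the low product word in .x, while the removed asm wrote the high word there. The same computation in portable C++ (a standalone sketch, not CK API; the pair is returned explicitly as hi/lo):

#include <cstdint>
#include <cstdio>

struct U32Pair { uint32_t hi, lo; };

// Full 32x32 -> 64-bit unsigned multiply, split into high and low words --
// the building block of the Philox round function.
U32Pair u32_high_low_multi(uint32_t a, uint32_t b)
{
    const uint64_t prod = static_cast<uint64_t>(a) * b;
    return {static_cast<uint32_t>(prod >> 32), static_cast<uint32_t>(prod)};
}

int main()
{
    // 0xD2511F53 and 0x9E3779B9 are constants used by Philox4x32.
    const U32Pair r = u32_high_low_multi(0xD2511F53u, 0x9E3779B9u);
    std::printf("hi=0x%08x lo=0x%08x\n", r.hi, r.lo);
    return 0;
}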