Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ee54ef41
Commit
ee54ef41
authored
Mar 02, 2023
by
danyao12
Browse files
Fixed a bug where Z-matrix saved permute
parent
c97a3a0d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
63 deletions
+29
-63
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
...oftmax_gemm/batched_multihead_attention_backward_fp16.cpp
+9
-31
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1_fp16.cpp
...ax_gemm/batched_multihead_attention_backward_pt1_fp16.cpp
+7
-28
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_fp16.cpp
...e_softmax_gemm/batched_multihead_attention_train_fp16.cpp
+13
-4
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
View file @
ee54ef41
...
@@ -59,7 +59,6 @@ using Scale = ck::tensor_operation::element_wise::Scale;
...
@@ -59,7 +59,6 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using
QKVElementOp
=
PassThrough
;
using
QKVElementOp
=
PassThrough
;
using
YElementOp
=
PassThrough
;
using
YElementOp
=
PassThrough
;
using
VElementOp
=
Scale
;
using
DataType
=
F16
;
using
DataType
=
F16
;
using
GemmDataType
=
F16
;
using
GemmDataType
=
F16
;
...
@@ -533,30 +532,8 @@ int run(int argc, char* argv[])
...
@@ -533,30 +532,8 @@ int run(int argc, char* argv[])
[
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
[
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
k_gs_ns_ks
.
ForEach
(
k_gs_ns_ks
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
[
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
z_gs_ms_ns
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
z_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
v_gs_os_ns
.
ForEach
(
v_gs_os_ns
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
[
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
lse_gs_ms
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
lse_g_m
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
])
=
self
(
idx
);
});
run_attention_fwd_host
(
q_g_m_k
,
k_g_n_k
,
v_g_n_o
,
alpha
,
s_g_m_n
,
p_g_m_n
,
y_g_m_o
,
lse_g_m
,
p_drop_g_m_n
,
z_g_m_n
,
p_dropout_in_16bits
,
rp_dropout
);
y_gs_ms_os
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
y_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
});
lse_gs_ms
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
lse_g_m
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
]);
});
// qkv gradients have the same descriptor as with qkv
// qkv gradients have the same descriptor as with qkv
DeviceMem
q_device_buf
(
sizeof
(
DataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
q_device_buf
(
sizeof
(
DataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
...
@@ -574,11 +551,7 @@ int run(int argc, char* argv[])
...
@@ -574,11 +551,7 @@ int run(int argc, char* argv[])
k_device_buf
.
ToDevice
(
k_gs_ns_ks
.
mData
.
data
());
k_device_buf
.
ToDevice
(
k_gs_ns_ks
.
mData
.
data
());
z_device_buf
.
ToDevice
(
z_gs_ms_ns
.
mData
.
data
());
z_device_buf
.
ToDevice
(
z_gs_ms_ns
.
mData
.
data
());
v_device_buf
.
ToDevice
(
v_gs_os_ns
.
mData
.
data
());
v_device_buf
.
ToDevice
(
v_gs_os_ns
.
mData
.
data
());
y_device_buf
.
ToDevice
(
y_gs_ms_os
.
mData
.
data
());
lse_device_buf
.
ToDevice
(
lse_gs_ms
.
mData
.
data
());
ygrad_device_buf
.
ToDevice
(
ygrad_gs_ms_os
.
mData
.
data
());
ygrad_device_buf
.
ToDevice
(
ygrad_gs_ms_os
.
mData
.
data
());
kgrad_device_buf
.
SetZero
();
vgrad_device_buf
.
SetZero
();
auto
gemm
=
DeviceGemmInstance
{};
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
invoker
=
gemm
.
MakeInvoker
();
...
@@ -688,13 +661,15 @@ int run(int argc, char* argv[])
...
@@ -688,13 +661,15 @@ int run(int argc, char* argv[])
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
// copy z matirx data form device
// copy z matirx data form device
z_device_buf
.
FromDevice
(
z_g_m_n
.
mData
.
data
());
z_device_buf
.
FromDevice
(
z_gs_ms_ns
.
mData
.
data
());
z_gs_ms_ns
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
z_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
// std::cout << "z_g_m_n ref:\n" << z_g_m_n;
// std::cout << "z_g_m_n ref:\n" << z_g_m_n;
bool
pass
=
true
;
bool
pass
=
true
;
if
(
do_verification
)
if
(
do_verification
)
{
{
// run f
owa
d again for y, cause z_g_m_n update
// run f
w
d again for y, cause z_g_m_n update
run_attention_fwd_host
(
q_g_m_k
,
run_attention_fwd_host
(
q_g_m_k
,
k_g_n_k
,
k_g_n_k
,
v_g_n_o
,
v_g_n_o
,
...
@@ -710,7 +685,10 @@ int run(int argc, char* argv[])
...
@@ -710,7 +685,10 @@ int run(int argc, char* argv[])
y_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
y_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
y_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
self
(
idx
)
=
y_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
});
});
lse_gs_ms
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
lse_g_m
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
]);
});
y_device_buf
.
ToDevice
(
y_gs_ms_os
.
mData
.
data
());
y_device_buf
.
ToDevice
(
y_gs_ms_os
.
mData
.
data
());
lse_device_buf
.
ToDevice
(
lse_gs_ms
.
mData
.
data
());
// call kernel again
// call kernel again
kgrad_device_buf
.
SetZero
();
// reset global accum buffer and rerun
kgrad_device_buf
.
SetZero
();
// reset global accum buffer and rerun
...
@@ -751,7 +729,7 @@ int run(int argc, char* argv[])
...
@@ -751,7 +729,7 @@ int run(int argc, char* argv[])
#if PRINT_HOST
#if PRINT_HOST
{
{
std
::
cout
<<
"===== dP = dY * V^T
\n
"
;
std
::
cout
<<
"===== dP = dY * V^T
\n
"
;
std
::
cout
<<
"ygrad_
drop_
g_m_o ref:
\n
"
<<
ygrad_
drop_
g_m_
n
;
std
::
cout
<<
"ygrad_g_m_o ref:
\n
"
<<
ygrad_g_m_
o
;
std
::
cout
<<
"v_g_o_n ref:
\n
"
<<
v_g_o_n
;
std
::
cout
<<
"v_g_o_n ref:
\n
"
<<
v_g_o_n
;
std
::
cout
<<
"pgrad_drop_g_m_n ref:
\n
"
<<
pgrad_drop_g_m_n
;
std
::
cout
<<
"pgrad_drop_g_m_n ref:
\n
"
<<
pgrad_drop_g_m_n
;
}
}
...
@@ -763,7 +741,7 @@ int run(int argc, char* argv[])
...
@@ -763,7 +741,7 @@ int run(int argc, char* argv[])
z_g_m_n
,
pgrad_drop_g_m_n
,
pgrad_g_m_n
,
p_dropout_in_16bits
,
rp_dropout
);
z_g_m_n
,
pgrad_drop_g_m_n
,
pgrad_g_m_n
,
p_dropout_in_16bits
,
rp_dropout
);
ref_dropout_invoker
.
Run
(
ref_dropout_argment
);
ref_dropout_invoker
.
Run
(
ref_dropout_argment
);
// dS_i_j = P_i_j .* (dP_i_j -
dY_i dot Y_i)
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
sgrad_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx_gmn
)
{
sgrad_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx_gmn
)
{
float
ygrad_dot_y
=
0
;
float
ygrad_dot_y
=
0
;
for
(
int
o
=
0
;
o
<
O
;
o
++
)
for
(
int
o
=
0
;
o
<
O
;
o
++
)
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1_fp16.cpp
View file @
ee54ef41
...
@@ -349,7 +349,7 @@ int run(int argc, char* argv[])
...
@@ -349,7 +349,7 @@ int run(int argc, char* argv[])
float
alpha
=
1.
f
/
std
::
sqrt
(
K
);
float
alpha
=
1.
f
/
std
::
sqrt
(
K
);
bool
input_permute
=
false
;
bool
input_permute
=
true
;
//
false;
bool
output_permute
=
false
;
bool
output_permute
=
false
;
float
p_drop
=
0.2
;
float
p_drop
=
0.2
;
...
@@ -531,30 +531,8 @@ int run(int argc, char* argv[])
...
@@ -531,30 +531,8 @@ int run(int argc, char* argv[])
[
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
[
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
k_gs_ns_ks
.
ForEach
(
k_gs_ns_ks
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
[
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
z_gs_ms_ns
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
z_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
v_gs_os_ns
.
ForEach
(
v_gs_os_ns
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
[
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
lse_gs_ms
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
lse_g_m
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
])
=
self
(
idx
);
});
run_attention_fwd_host
(
q_g_m_k
,
k_g_n_k
,
v_g_n_o
,
alpha
,
s_g_m_n
,
p_g_m_n
,
y_g_m_o
,
lse_g_m
,
p_drop_g_m_n
,
z_g_m_n
,
p_dropout_in_16bits
,
rp_dropout
);
y_gs_ms_os
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
y_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
});
lse_gs_ms
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
lse_g_m
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
]);
});
// qkv gradients have the same descriptor as with qkv
// qkv gradients have the same descriptor as with qkv
DeviceMem
q_device_buf
(
sizeof
(
DataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
q_device_buf
(
sizeof
(
DataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
...
@@ -572,11 +550,7 @@ int run(int argc, char* argv[])
...
@@ -572,11 +550,7 @@ int run(int argc, char* argv[])
k_device_buf
.
ToDevice
(
k_gs_ns_ks
.
mData
.
data
());
k_device_buf
.
ToDevice
(
k_gs_ns_ks
.
mData
.
data
());
z_device_buf
.
ToDevice
(
z_gs_ms_ns
.
mData
.
data
());
z_device_buf
.
ToDevice
(
z_gs_ms_ns
.
mData
.
data
());
v_device_buf
.
ToDevice
(
v_gs_os_ns
.
mData
.
data
());
v_device_buf
.
ToDevice
(
v_gs_os_ns
.
mData
.
data
());
y_device_buf
.
ToDevice
(
y_gs_ms_os
.
mData
.
data
());
lse_device_buf
.
ToDevice
(
lse_gs_ms
.
mData
.
data
());
ygrad_device_buf
.
ToDevice
(
ygrad_gs_ms_os
.
mData
.
data
());
ygrad_device_buf
.
ToDevice
(
ygrad_gs_ms_os
.
mData
.
data
());
kgrad_device_buf
.
SetZero
();
vgrad_device_buf
.
SetZero
();
auto
gemm
=
DeviceGemmInstance
{};
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
invoker
=
gemm
.
MakeInvoker
();
...
@@ -686,7 +660,9 @@ int run(int argc, char* argv[])
...
@@ -686,7 +660,9 @@ int run(int argc, char* argv[])
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
// copy z matirx data form device
// copy z matirx data form device
z_device_buf
.
FromDevice
(
z_g_m_n
.
mData
.
data
());
z_device_buf
.
FromDevice
(
z_gs_ms_ns
.
mData
.
data
());
z_gs_ms_ns
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
z_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
// std::cout << "z_g_m_n ref:\n" << z_g_m_n;
// std::cout << "z_g_m_n ref:\n" << z_g_m_n;
bool
pass
=
true
;
bool
pass
=
true
;
...
@@ -708,7 +684,10 @@ int run(int argc, char* argv[])
...
@@ -708,7 +684,10 @@ int run(int argc, char* argv[])
y_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
y_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
y_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
self
(
idx
)
=
y_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
});
});
lse_gs_ms
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
lse_g_m
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
]);
});
y_device_buf
.
ToDevice
(
y_gs_ms_os
.
mData
.
data
());
y_device_buf
.
ToDevice
(
y_gs_ms_os
.
mData
.
data
());
lse_device_buf
.
ToDevice
(
lse_gs_ms
.
mData
.
data
());
// call kernel again
// call kernel again
kgrad_device_buf
.
SetZero
();
// reset global accum buffer and rerun
kgrad_device_buf
.
SetZero
();
// reset global accum buffer and rerun
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_fp16.cpp
View file @
ee54ef41
...
@@ -38,6 +38,7 @@ Kernel outputs:
...
@@ -38,6 +38,7 @@ Kernel outputs:
#include <numeric>
#include <numeric>
#include <initializer_list>
#include <initializer_list>
#include <cstdlib>
#include <cstdlib>
#include <fstream>
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
...
@@ -418,12 +419,12 @@ int run(int argc, char* argv[])
...
@@ -418,12 +419,12 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck
::
index_t
M
=
487
;
// 512
ck
::
index_t
M
=
129
;
// 512
ck
::
index_t
N
=
335
;
// 512
ck
::
index_t
N
=
128
;
// 512
ck
::
index_t
K
=
64
;
ck
::
index_t
K
=
64
;
ck
::
index_t
O
=
64
;
ck
::
index_t
O
=
64
;
ck
::
index_t
G0
=
4
;
// 54
ck
::
index_t
G0
=
1
;
// 54
ck
::
index_t
G1
=
6
;
// 16
ck
::
index_t
G1
=
1
;
// 16
float
alpha
=
1.
f
/
std
::
sqrt
(
K
);
float
alpha
=
1.
f
/
std
::
sqrt
(
K
);
...
@@ -822,6 +823,10 @@ int run(int argc, char* argv[])
...
@@ -822,6 +823,10 @@ int run(int argc, char* argv[])
y_device_buf
.
FromDevice
(
y_gs_ms_os_device_result
.
mData
.
data
());
y_device_buf
.
FromDevice
(
y_gs_ms_os_device_result
.
mData
.
data
());
lse_device_buf
.
FromDevice
(
lse_gs_ms_device_result
.
mData
.
data
());
lse_device_buf
.
FromDevice
(
lse_gs_ms_device_result
.
mData
.
data
());
// std::cout << "z_fwd_gs_ms_ns ref:\n" << z_fwd_gs_ms_ns;
std
::
ofstream
fwd_file
(
"./z_fwd_matrix_txt"
);
fwd_file
<<
z_fwd_gs_ms_ns
<<
std
::
endl
;
kgrad_device_buf
.
SetZero
();
kgrad_device_buf
.
SetZero
();
vgrad_device_buf
.
SetZero
();
vgrad_device_buf
.
SetZero
();
...
@@ -875,6 +880,10 @@ int run(int argc, char* argv[])
...
@@ -875,6 +880,10 @@ int run(int argc, char* argv[])
qgrad_device_buf
.
FromDevice
(
qgrad_gs_ms_ks_device_result
.
mData
.
data
());
qgrad_device_buf
.
FromDevice
(
qgrad_gs_ms_ks_device_result
.
mData
.
data
());
kgrad_device_buf
.
FromDevice
(
kgrad_gs_ns_ks_device_result
.
mData
.
data
());
kgrad_device_buf
.
FromDevice
(
kgrad_gs_ns_ks_device_result
.
mData
.
data
());
vgrad_device_buf
.
FromDevice
(
vgrad_gs_os_ns_device_result
.
mData
.
data
());
vgrad_device_buf
.
FromDevice
(
vgrad_gs_os_ns_device_result
.
mData
.
data
());
// std::cout << "z_bwd_gs_ms_ns ref:\n" << z_bwd_gs_ms_ns;
std
::
ofstream
bwd_file
(
"./z_bwd_matrix_txt"
);
bwd_file
<<
z_bwd_gs_ms_ns
<<
std
::
endl
;
}
}
q_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
q_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment