Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
4d720be3
Commit
4d720be3
authored
Feb 28, 2023
by
danyao12
Browse files
remove unnecessary host run
parent
4d140b5d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
72 additions
and
59 deletions
+72
-59
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1_bf16.cpp
...ax_gemm/batched_multihead_attention_backward_pt1_bf16.cpp
+39
-31
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1_fp16.cpp
...ax_gemm/batched_multihead_attention_backward_pt1_fp16.cpp
+33
-28
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1_bf16.cpp
View file @
4d720be3
...
@@ -50,8 +50,8 @@ template <ck::index_t... Is>
...
@@ -50,8 +50,8 @@ template <ck::index_t... Is>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
BF16
=
ck
::
bhalf_t
;
using
BF16
=
ck
::
bhalf_t
;
using
F32
=
float
;
using
F32
=
float
;
using
U16
=
unsigned
short
;
using
U16
=
unsigned
short
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Scale
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
using
Scale
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
...
@@ -160,7 +160,7 @@ using DeviceGemmInstance =
...
@@ -160,7 +160,7 @@ using DeviceGemmInstance =
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
MaskingSpec
>
;
// MaskingSpecialization
#else
#else
//2nd template
//
2nd template
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
<
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
<
NumDimG
,
NumDimG
,
...
@@ -531,30 +531,32 @@ int run(int argc, char* argv[])
...
@@ -531,30 +531,32 @@ int run(int argc, char* argv[])
[
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
[
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
k_gs_ns_ks
.
ForEach
(
k_gs_ns_ks
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
[
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
z_gs_ms_ns
.
ForEach
(
// z_gs_ms_ns.ForEach(
[
&
](
auto
&
self
,
auto
idx
)
{
z_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
// [&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
// });
v_gs_os_ns
.
ForEach
(
v_gs_os_ns
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
[
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
lse_gs_ms
.
ForEach
(
// lse_gs_ms.ForEach(
[
&
](
auto
&
self
,
auto
idx
)
{
lse_g_m
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
])
=
self
(
idx
);
});
// [&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });
run_attention_fwd_host
(
q_g_m_k
,
// run_attention_fwd_host(q_g_m_k,
k_g_n_k
,
// k_g_n_k,
v_g_n_o
,
// v_g_n_o,
alpha
,
// alpha,
s_g_m_n
,
// s_g_m_n,
p_g_m_n
,
// p_g_m_n,
y_g_m_o
,
// y_g_m_o,
lse_g_m
,
// lse_g_m,
p_drop_g_m_n
,
// p_drop_g_m_n,
z_g_m_n
,
// z_g_m_n,
p_dropout_in_16bits
,
// p_dropout_in_16bits,
rp_dropout
);
// rp_dropout);
y_gs_ms_os
.
ForEach
(
// y_gs_ms_os.ForEach(
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
y_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
});
// [&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
lse_gs_ms
.
ForEach
(
// });
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
lse_g_m
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
]);
});
// lse_gs_ms.ForEach(
// [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
// qkv gradients have the same descriptor as with qkv
// qkv gradients have the same descriptor as with qkv
DeviceMem
q_device_buf
(
sizeof
(
DataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
q_device_buf
(
sizeof
(
DataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
...
@@ -572,11 +574,11 @@ int run(int argc, char* argv[])
...
@@ -572,11 +574,11 @@ int run(int argc, char* argv[])
k_device_buf
.
ToDevice
(
k_gs_ns_ks
.
mData
.
data
());
k_device_buf
.
ToDevice
(
k_gs_ns_ks
.
mData
.
data
());
z_device_buf
.
ToDevice
(
z_gs_ms_ns
.
mData
.
data
());
z_device_buf
.
ToDevice
(
z_gs_ms_ns
.
mData
.
data
());
v_device_buf
.
ToDevice
(
v_gs_os_ns
.
mData
.
data
());
v_device_buf
.
ToDevice
(
v_gs_os_ns
.
mData
.
data
());
y_device_buf
.
ToDevice
(
y_gs_ms_os
.
mData
.
data
());
//
y_device_buf.ToDevice(y_gs_ms_os.mData.data());
lse_device_buf
.
ToDevice
(
lse_gs_ms
.
mData
.
data
());
//
lse_device_buf.ToDevice(lse_gs_ms.mData.data());
ygrad_device_buf
.
ToDevice
(
ygrad_gs_ms_os
.
mData
.
data
());
ygrad_device_buf
.
ToDevice
(
ygrad_gs_ms_os
.
mData
.
data
());
kgrad_device_buf
.
SetZero
();
//
kgrad_device_buf.SetZero();
vgrad_device_buf
.
SetZero
();
//
vgrad_device_buf.SetZero();
auto
gemm
=
DeviceGemmInstance
{};
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
invoker
=
gemm
.
MakeInvoker
();
...
@@ -708,7 +710,10 @@ int run(int argc, char* argv[])
...
@@ -708,7 +710,10 @@ int run(int argc, char* argv[])
y_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
y_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
y_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
self
(
idx
)
=
y_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
});
});
lse_gs_ms
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
lse_g_m
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
]);
});
y_device_buf
.
ToDevice
(
y_gs_ms_os
.
mData
.
data
());
y_device_buf
.
ToDevice
(
y_gs_ms_os
.
mData
.
data
());
lse_device_buf
.
ToDevice
(
lse_gs_ms
.
mData
.
data
());
// call kernel again
// call kernel again
kgrad_device_buf
.
SetZero
();
// reset global accum buffer and rerun
kgrad_device_buf
.
SetZero
();
// reset global accum buffer and rerun
...
@@ -768,9 +773,12 @@ int run(int argc, char* argv[])
...
@@ -768,9 +773,12 @@ int run(int argc, char* argv[])
{
{
auto
idx_gmo
=
idx_gmn
;
auto
idx_gmo
=
idx_gmn
;
idx_gmo
[
2
]
=
o
;
idx_gmo
[
2
]
=
o
;
ygrad_dot_y
+=
ck
::
type_convert
<
AccDataType
>
(
ygrad_g_m_o
(
idx_gmo
))
*
ck
::
type_convert
<
AccDataType
>
(
y_g_m_o
(
idx_gmo
));
ygrad_dot_y
+=
ck
::
type_convert
<
AccDataType
>
(
ygrad_g_m_o
(
idx_gmo
))
*
ck
::
type_convert
<
AccDataType
>
(
y_g_m_o
(
idx_gmo
));
}
}
self
(
idx_gmn
)
=
ck
::
type_convert
<
DataType
>
(
ck
::
type_convert
<
AccDataType
>
(
p_g_m_n
(
idx_gmn
))
*
(
ck
::
type_convert
<
AccDataType
>
(
pgrad_g_m_n
(
idx_gmn
))
-
ygrad_dot_y
));
self
(
idx_gmn
)
=
ck
::
type_convert
<
DataType
>
(
ck
::
type_convert
<
AccDataType
>
(
p_g_m_n
(
idx_gmn
))
*
(
ck
::
type_convert
<
AccDataType
>
(
pgrad_g_m_n
(
idx_gmn
))
-
ygrad_dot_y
));
});
});
#if PRINT_HOST
#if PRINT_HOST
{
{
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1_fp16.cpp
View file @
4d720be3
...
@@ -344,8 +344,8 @@ int run(int argc, char* argv[])
...
@@ -344,8 +344,8 @@ int run(int argc, char* argv[])
ck
::
index_t
N
=
512
;
// 512
ck
::
index_t
N
=
512
;
// 512
ck
::
index_t
K
=
64
;
ck
::
index_t
K
=
64
;
ck
::
index_t
O
=
64
;
ck
::
index_t
O
=
64
;
ck
::
index_t
G0
=
4
;
// 54
ck
::
index_t
G0
=
5
4
;
// 54
ck
::
index_t
G1
=
6
;
// 16
ck
::
index_t
G1
=
1
6
;
// 16
float
alpha
=
1.
f
/
std
::
sqrt
(
K
);
float
alpha
=
1.
f
/
std
::
sqrt
(
K
);
...
@@ -531,30 +531,32 @@ int run(int argc, char* argv[])
...
@@ -531,30 +531,32 @@ int run(int argc, char* argv[])
[
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
[
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
k_gs_ns_ks
.
ForEach
(
k_gs_ns_ks
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
[
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
z_gs_ms_ns
.
ForEach
(
// z_gs_ms_ns.ForEach(
[
&
](
auto
&
self
,
auto
idx
)
{
z_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
// [&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
// });
v_gs_os_ns
.
ForEach
(
v_gs_os_ns
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
[
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
lse_gs_ms
.
ForEach
(
// lse_gs_ms.ForEach(
[
&
](
auto
&
self
,
auto
idx
)
{
lse_g_m
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
])
=
self
(
idx
);
});
// [&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });
run_attention_fwd_host
(
q_g_m_k
,
// run_attention_fwd_host(q_g_m_k,
k_g_n_k
,
// k_g_n_k,
v_g_n_o
,
// v_g_n_o,
alpha
,
// alpha,
s_g_m_n
,
// s_g_m_n,
p_g_m_n
,
// p_g_m_n,
y_g_m_o
,
// y_g_m_o,
lse_g_m
,
// lse_g_m,
p_drop_g_m_n
,
// p_drop_g_m_n,
z_g_m_n
,
// z_g_m_n,
p_dropout_in_16bits
,
// p_dropout_in_16bits,
rp_dropout
);
// rp_dropout);
y_gs_ms_os
.
ForEach
(
// y_gs_ms_os.ForEach(
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
y_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
});
// [&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
lse_gs_ms
.
ForEach
(
// });
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
lse_g_m
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
]);
});
// lse_gs_ms.ForEach(
// [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
// qkv gradients have the same descriptor as with qkv
// qkv gradients have the same descriptor as with qkv
DeviceMem
q_device_buf
(
sizeof
(
DataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
q_device_buf
(
sizeof
(
DataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
...
@@ -572,11 +574,11 @@ int run(int argc, char* argv[])
...
@@ -572,11 +574,11 @@ int run(int argc, char* argv[])
k_device_buf
.
ToDevice
(
k_gs_ns_ks
.
mData
.
data
());
k_device_buf
.
ToDevice
(
k_gs_ns_ks
.
mData
.
data
());
z_device_buf
.
ToDevice
(
z_gs_ms_ns
.
mData
.
data
());
z_device_buf
.
ToDevice
(
z_gs_ms_ns
.
mData
.
data
());
v_device_buf
.
ToDevice
(
v_gs_os_ns
.
mData
.
data
());
v_device_buf
.
ToDevice
(
v_gs_os_ns
.
mData
.
data
());
y_device_buf
.
ToDevice
(
y_gs_ms_os
.
mData
.
data
());
//
y_device_buf.ToDevice(y_gs_ms_os.mData.data());
lse_device_buf
.
ToDevice
(
lse_gs_ms
.
mData
.
data
());
//
lse_device_buf.ToDevice(lse_gs_ms.mData.data());
ygrad_device_buf
.
ToDevice
(
ygrad_gs_ms_os
.
mData
.
data
());
ygrad_device_buf
.
ToDevice
(
ygrad_gs_ms_os
.
mData
.
data
());
kgrad_device_buf
.
SetZero
();
//
kgrad_device_buf.SetZero();
vgrad_device_buf
.
SetZero
();
//
vgrad_device_buf.SetZero();
auto
gemm
=
DeviceGemmInstance
{};
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
invoker
=
gemm
.
MakeInvoker
();
...
@@ -708,7 +710,10 @@ int run(int argc, char* argv[])
...
@@ -708,7 +710,10 @@ int run(int argc, char* argv[])
y_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
y_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
y_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
self
(
idx
)
=
y_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
});
});
lse_gs_ms
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
lse_g_m
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
]);
});
y_device_buf
.
ToDevice
(
y_gs_ms_os
.
mData
.
data
());
y_device_buf
.
ToDevice
(
y_gs_ms_os
.
mData
.
data
());
lse_device_buf
.
ToDevice
(
lse_gs_ms
.
mData
.
data
());
// call kernel again
// call kernel again
kgrad_device_buf
.
SetZero
();
// reset global accum buffer and rerun
kgrad_device_buf
.
SetZero
();
// reset global accum buffer and rerun
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment