gaoqiong / composable_kernel

Commit 88980945, authored Mar 06, 2023 by guangzlu
Parent: db8018de

    updated new dropout for attn fwd

Showing 7 changed files, with 57 additions and 49 deletions:
- example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward.cpp (+6 / -6)
- example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_bf16.cpp (+6 / -6)
- example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward.cpp (+5 / -5)
- example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_bf16.cpp (+5 / -5)
- example/32_batched_gemm_scale_softmax_gemm/run_batched_multihead_attention_forward.inc (+3 / -3)
- example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc (+4 / -5)
- include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp (+28 / -19)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward.cpp

@@ -103,17 +103,17 @@ using DeviceGemmInstance =
         TensorSpecC,
         1,
         256,
-        256,          // MPerBlock
+        128,          // MPerBlock
         128,          // NPerBlock
-        32,           // KPerBlock
+        64,           // KPerBlock
         64,           // Gemm1NPerBlock
-        32,           // Gemm1KPerBlock
+        64,           // Gemm1KPerBlock
         8,            // AK1
         8,            // BK1
         2,            // B1K1
         32,           // MPerXDL
         32,           // NPerXDL
-        2,            // MXdlPerWave
+        1,            // MXdlPerWave
         4,            // NXdlPerWave
         2,            // Gemm1NXdlPerWave
         S<4, 64, 1>,  // ABlockTransfer

@@ -130,11 +130,11 @@ using DeviceGemmInstance =
         8,
         8,
         true,
-        S<16, 16, 1>, // B1BlockTransfer
+        S<8, 32, 1>,  // B1BlockTransfer
         S<0, 2, 1>,
         S<0, 2, 1>,
         1,
-        4,
+        2,
         2,
         false,
         1,            // CShuffleMXdlPerWavePerShuffle
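A note on the tile change (not part of the commit): the batched instances halve MPerBlock (256 to 128) and MXdlPerWave (2 to 1) while doubling KPerBlock and Gemm1KPerBlock. Under CK's usual xdlops decomposition, MWaves = MPerBlock / (MXdlPerWave * MPerXDL), so both the old and the new shape still map onto 4 x 1 waves of 64 lanes = BlockSize 256. A minimal sanity-check sketch of that arithmetic; the decomposition is stated here as an assumption for illustration, not as the library's authoritative constraint set:

```cpp
// Standalone sketch (not from the commit): check that the new tile shape
// still fills a 256-thread block, assuming MWaves = MPerBlock / (MXdlPerWave
// * MPerXDL) and 64-lane wavefronts, as in CK's xdlops gridwise gemms.
#include <cstdio>

int main()
{
    constexpr int BlockSize   = 256;
    constexpr int MPerBlock   = 128; // was 256
    constexpr int NPerBlock   = 128;
    constexpr int MPerXDL     = 32;
    constexpr int NPerXDL     = 32;
    constexpr int MXdlPerWave = 1;   // was 2
    constexpr int NXdlPerWave = 4;

    constexpr int MWaves = MPerBlock / (MXdlPerWave * MPerXDL); // 4 (unchanged)
    constexpr int NWaves = NPerBlock / (NXdlPerWave * NPerXDL); // 1
    static_assert(BlockSize == 64 * MWaves * NWaves, "waves must fill the block");

    std::printf("MWaves=%d, NWaves=%d\n", MWaves, NWaves);
    return 0;
}
```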
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_bf16.cpp

@@ -99,17 +99,17 @@ using DeviceGemmInstance =
         TensorSpecC,
         1,
         256,
-        256,          // MPerBlock
+        128,          // MPerBlock
         128,          // NPerBlock
-        32,           // KPerBlock
+        64,           // KPerBlock
         64,           // Gemm1NPerBlock
-        32,           // Gemm1KPerBlock
+        64,           // Gemm1KPerBlock
         8,            // AK1
         8,            // BK1
         2,            // B1K1
         32,           // MPerXDL
         32,           // NPerXDL
-        2,            // MXdlPerWave
+        1,            // MXdlPerWave
         4,            // NXdlPerWave
         2,            // Gemm1NXdlPerWave
         S<4, 64, 1>,  // ABlockTransfer

@@ -126,11 +126,11 @@ using DeviceGemmInstance =
         8,
         8,
         true,
-        S<16, 16, 1>, // B1BlockTransfer
+        S<8, 32, 1>,  // B1BlockTransfer
         S<0, 2, 1>,
         S<0, 2, 1>,
         1,
-        4,
+        2,
         2,
         false,
         1,            // CShuffleMXdlPerWavePerShuffle
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward.cpp

@@ -105,8 +105,8 @@ using DeviceGemmInstance =
         256,
         128,          // MPerBlock
         128,          // NPerBlock
-        32,           // KPerBlock
-        128,          // Gemm1NPerBlock
+        64,           // KPerBlock
+        64,           // Gemm1NPerBlock
         64,           // Gemm1KPerBlock
         8,            // AK1
         8,            // BK1

@@ -115,7 +115,7 @@ using DeviceGemmInstance =
         32,           // NPerXDL
         1,            // MXdlPerWave
         4,            // NXdlPerWave
-        4,            // Gemm1NXdlPerWave
+        2,            // Gemm1NXdlPerWave
         S<4, 64, 1>,  // ABlockTransfer
         S<1, 0, 2>,
         S<1, 0, 2>,

@@ -130,11 +130,11 @@ using DeviceGemmInstance =
         8,
         8,
         true,
-        S<16, 16, 1>, // B1BlockTransfer
+        S<8, 32, 1>,  // B1BlockTransfer
         S<0, 2, 1>,
         S<0, 2, 1>,
         1,
-        4,
+        2,
         2,
         false,
         1,            // CShuffleMXdlPerWavePerShuffle
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_bf16.cpp

@@ -101,8 +101,8 @@ using DeviceGemmInstance =
         256,
         128,          // MPerBlock
         128,          // NPerBlock
-        32,           // KPerBlock
-        128,          // Gemm1NPerBlock
+        64,           // KPerBlock
+        64,           // Gemm1NPerBlock
         64,           // Gemm1KPerBlock
         8,            // AK1
         8,            // BK1

@@ -111,7 +111,7 @@ using DeviceGemmInstance =
         32,           // NPerXDL
         1,            // MXdlPerWave
         4,            // NXdlPerWave
-        4,            // Gemm1NXdlPerWave
+        2,            // Gemm1NXdlPerWave
         S<4, 64, 1>,  // ABlockTransfer
         S<1, 0, 2>,
         S<1, 0, 2>,

@@ -126,11 +126,11 @@ using DeviceGemmInstance =
         8,
         8,
         true,
-        S<16, 16, 1>, // B1BlockTransfer
+        S<8, 32, 1>,  // B1BlockTransfer
         S<0, 2, 1>,
         S<0, 2, 1>,
         1,
-        4,
+        2,
         2,
         false,
         1,            // CShuffleMXdlPerWavePerShuffle
example/32_batched_gemm_scale_softmax_gemm/run_batched_multihead_attention_forward.inc

@@ -9,8 +9,8 @@ int run(int argc, char* argv[])
     // GEMM shape for A/B0/B1/C
     // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
-    ck::index_t M = 1000; // 120
-    ck::index_t N = 1000; // 1000
+    ck::index_t M = 512; // 120
+    ck::index_t N = 512; // 1000
     ck::index_t K = 64;
     ck::index_t O = 64;

@@ -97,7 +97,7 @@ int run(int argc, char* argv[])
     std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
     std::vector<ck::index_t> z_gs_ms_ns_strides =
-        output_permute
+        input_permute
             ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
             : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
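The two Z stride vectors in the hunk above describe the same logical [G0, G1, M, N] tensor stored in two memory orders; the commit only changes which flag (input_permute rather than output_permute) selects the permuted order. A small host-side sketch of the offset arithmetic, written for this note rather than taken from the repository:

```cpp
// Minimal sketch (not from the commit) showing how the two stride vectors
// place the same logical element (g0, g1, m, n) at different linear offsets.
#include <cstdio>
#include <vector>

// Linear offset of (g0, g1, m, n) under per-dimension strides s.
long offset(const std::vector<long>& s, long g0, long g1, long m, long n)
{
    return g0 * s[0] + g1 * s[1] + m * s[2] + n * s[3];
}

int main()
{
    const long G1 = 2, M = 4, N = 8; // toy sizes; G0 only scales the outer stride
    // Permuted storage [G0, M, G1, N]: M varies slower than G1.
    std::vector<long> permuted{M * G1 * N, N, G1 * N, 1};
    // Packed storage [G0, G1, M, N].
    std::vector<long> packed{G1 * M * N, M * N, N, 1};

    // Same logical index, different linear locations:
    std::printf("permuted offset = %ld\n", offset(permuted, 0, 1, 2, 3)); // 1*8 + 2*16 + 3 = 43
    std::printf("packed   offset = %ld\n", offset(packed, 0, 1, 2, 3));   // 1*32 + 2*8 + 3 = 51
    return 0;
}
```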
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc

@@ -10,8 +10,7 @@ int run(int argc, char* argv[])
     bool input_permute  = false;
     bool output_permute = true;
-    float p_drop = 0.1;
+    float p_drop = 0.2;
     float p_dropout = 1 - p_drop;
     uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
     float rp_dropout = 1.0 / p_dropout;

@@ -84,8 +83,8 @@ int run(int argc, char* argv[])
     int M  = 128 * (rand() % 8 + 1);
     int N  = 128 * (rand() % 8 + 1);
-    int K  = 128;
-    int O  = 128;
+    int K  = 64;
+    int O  = 64;
     int G0 = rand() % 3 + 1;
     int G1 = rand() % 5 + 1;

@@ -117,7 +116,7 @@ int run(int argc, char* argv[])
     std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
     std::vector<ck::index_t> z_gs_ms_ns_strides =
-        output_permute
+        input_permute
             ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
             : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
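The dropout block above derives everything from p_drop: the keep probability p_dropout = 1 - p_drop, a 16-bit threshold floor(p_dropout * 65535), and the inverted-dropout rescale rp_dropout = 1 / p_dropout. A host-side sketch of that arithmetic follows; comparing a 16-bit random draw against the threshold is an assumption about how a Philox-based kernel could consume these values, not a transcription of the device code:

```cpp
// Host-side sketch (not from the commit) of the dropout parameters above.
// Only the three derived values are taken from the diff; the keep test on a
// 16-bit uniform draw is an assumed consumption pattern, shown for intuition.
#include <cmath>
#include <cstdint>
#include <cstdio>

int main()
{
    float p_drop    = 0.2f;       // new value in this commit (was 0.1)
    float p_dropout = 1 - p_drop; // keep probability
    uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
    float rp_dropout = 1.0f / p_dropout; // inverted-dropout rescale

    // A 16-bit uniform draw r is "kept" when r <= threshold, so P(keep) is
    // approximately p_dropout, and kept values are scaled by rp_dropout so
    // the expected value of the output matches the input.
    uint16_t r = 40000; // example draw
    float x = 1.0f;
    float y = (r <= p_dropout_in_16bits) ? x * rp_dropout : 0.0f;
    std::printf("threshold=%u keep=%d y=%f\n",
                (unsigned)p_dropout_in_16bits, (int)(r <= p_dropout_in_16bits), y);
    return 0;
}
```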
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp

@@ -274,11 +274,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
         const auto K =
             a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
         const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1);

-        // if(Gemm1N != K)
-        // {
-        //     std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
-        //     return false;
-        // }
+        if(Gemm1N != K)
+        {
+            std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
+            return false;
+        }

         if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1)))
         {

@@ -852,7 +852,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
            make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
                                                           I1, // NBlockID
                                                           m0, // MRepeat
-                                                          n0, // NRepeat
+                                                          I1, // n0, // NRepeat
                                                           m1, // MWaveId
                                                           n1, // NWaveId
                                                           m2, // MPerXdl

@@ -883,7 +883,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                     Sequence<I1, // MBlockId
                              I1, // NBlockID
                              m0, // MRepeat
-                             n0, // NRepeat
+                             I1, // NRepeat
                              m1, // MWaveId
                              n1, // NWaveId
                              m2, // MPerXdl

@@ -1006,25 +1006,34 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                // save z to global
                if(p_z_grid)
                {
-                   // P_dropped
-                   blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
-                                                           decltype(z_tenor_buffer),
-                                                           false>(
-                       acc_thread_buf, ph, z_tenor_buffer);
-                   z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                                                    make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
-                                                    z_tenor_buffer,
-                                                    z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                                                    z_grid_buf);
-                   z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
-                       z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                       make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
+                   static_for<0, n0, 1>{}([&](auto i) {
+                       blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
+                                                               decltype(z_tenor_buffer),
+                                                               false,
+                                                               decltype(n0),
+                                                               decltype(i)>(
+                           acc_thread_buf, ph, z_tenor_buffer);
+                       z_thread_copy_vgpr_to_global.Run(
+                           z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                           make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
+                           z_tenor_buffer,
+                           z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                           z_grid_buf);
+                       z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+                           z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                           make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
+                   });
+                   z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+                       z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                       make_multi_index(0, 0, 0, -(n0.value), 0, 0, 0, 0, 0, 0));
+                   z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+                       z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                       make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
                }
                else
                {
+                   // ignore = z_grid_buf;
                    // P_dropped
                    blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf), false>(
                        acc_thread_buf, ph);
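The rewritten store loop above iterates static_for<0, n0, 1> over the NRepeat tiles: each pass applies dropout to one tile, writes it to the Z grid buffer, and advances the destination slice window by one along the repeat dimension; after the loop, the window is rewound by n0.value and stepped once along the next-outer block dimension. A plain C++ sketch of just that index motion, with dimension names taken from the descriptor comments in the diff rather than from the kernel itself:

```cpp
// Index-motion sketch (not CK code) for the new per-tile Z store above.
// The real kernel moves a 10-dimensional slice window; this traces only the
// two coordinates that change (dim 1 ~ NBlockId, dim 3 ~ NRepeat, per the
// thread-descriptor comments in the diff).
#include <cstdio>

int main()
{
    const int n0 = 4; // NRepeat tiles per block (illustrative value)
    int n_block = 0, n_repeat = 0;

    for(int pass = 0; pass < 2; ++pass) // two outer passes, for illustration
    {
        for(int i = 0; i < n0; ++i)     // static_for<0, n0, 1> in the kernel
        {
            std::printf("store tile: n_block=%d n_repeat=%d\n", n_block, n_repeat);
            n_repeat += 1;              // make_multi_index(0, 0, 0, 1, ...)
        }
        n_repeat -= n0;                 // make_multi_index(0, 0, 0, -(n0.value), ...)
        n_block += 1;                   // make_multi_index(0, 1, 0, 0, ...)
    }
    return 0; // the window ends each pass rewound in NRepeat, advanced in NBlockId
}
```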