gaoqiong / composable_kernel · Commit 2723b268

fixed bugs and standardize the code

Authored Jun 16, 2023 by guangzlu
Parent: e38d2a5d

Showing 6 changed files with 53 additions and 14 deletions
Changed files:

  include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v4.hpp   +8 -2
  include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v5.hpp   +8 -2
  include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle_v2.hpp    +8 -2
  include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt6.hpp       +11 -4
  include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt7.hpp       +11 -4
  include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle_pt2.hpp        +7 -0
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v4.hpp

@@ -819,6 +819,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
         c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ =
             GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(z_grid_desc_m_n_);
         // Print();
+        m_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
+        n_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]);
     }

     void Print() const

@@ -906,6 +909,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
         float p_drop_;
         unsigned long long seed_;
         unsigned long long offset_;
+        index_t m_raw_padded_;
+        index_t n_raw_padded_;
     };

     // Invoker

@@ -988,8 +994,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                 arg.p_drop_,
                 arg.seed_,
                 arg.offset_,
-                arg.raw_lengths_mz_nz_kz_gemm1nz_[0],
-                arg.raw_lengths_mz_nz_kz_gemm1nz_[1]);
+                arg.m_raw_padded_,
+                arg.n_raw_padded_);
         };

         ave_time = launch_kernel(integral_constant<bool, false>{});
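All three hunks in this file (and the matching hunks in the v5 and forward device headers below) follow one pattern: the Argument constructor caches the M/N extents rounded up by GridwiseGemm::GetPaddedSize, and the Invoker forwards those cached padded values to the kernel instead of the raw lengths. Below is a minimal standalone sketch of that flow; the simplified GridwiseGemm and the group size of 4 are stand-ins for the real compile-time machinery, not the library's actual types.

    #include <cstdio>

    using index_t = int;

    // stand-in for the real GridwiseGemm; the actual helper is added to the
    // gridwise headers in this commit (see below) and derives the group size
    // from the selected mfma instruction at compile time
    struct GridwiseGemm
    {
        static constexpr index_t group_size = 4; // assumption for this sketch

        static index_t GetPaddedSize(const index_t size)
        {
            return (size + group_size - 1) / group_size * group_size;
        }
    };

    struct Argument
    {
        index_t raw_lengths_mz_nz_kz_gemm1nz_[4];
        index_t m_raw_padded_;
        index_t n_raw_padded_;

        Argument(index_t m, index_t n, index_t k, index_t gemm1n)
            : raw_lengths_mz_nz_kz_gemm1nz_{m, n, k, gemm1n}
        {
            // mirrors the constructor hunk above
            m_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
            n_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]);
        }
    };

    int main()
    {
        Argument arg(30, 50, 64, 64);
        // the Invoker hunk now hands the kernel these two members instead of
        // raw_lengths_mz_nz_kz_gemm1nz_[0] and [1]
        std::printf("M: 30 -> %d, N: 50 -> %d\n", arg.m_raw_padded_, arg.n_raw_padded_);
    }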
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v5.hpp

@@ -832,6 +832,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ =
             GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(z_grid_desc_m_n_);
         // Print();
+        m_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
+        n_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]);
     }

     void Print() const

@@ -919,6 +922,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         float p_drop_;
         unsigned long long seed_;
         unsigned long long offset_;
+        index_t m_raw_padded_;
+        index_t n_raw_padded_;
     };

     // Invoker

@@ -1005,8 +1011,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                 arg.p_drop_,
                 arg.seed_,
                 arg.offset_,
-                arg.raw_lengths_mz_nz_kz_gemm1nz_[0],
-                arg.raw_lengths_mz_nz_kz_gemm1nz_[1]);
+                arg.m_raw_padded_,
+                arg.n_raw_padded_);
         };

         // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle_v2.hpp

@@ -648,6 +648,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
         z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
             GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_);
+        m_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
+        n_raw_padded_ = GridwiseGemm::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]);
+
         if(p_lse_grid == nullptr)
         {
             is_lse_storing_ = false;

@@ -728,6 +731,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
         bool is_dropout_;
         bool is_lse_storing_ = true;
+        index_t m_raw_padded_;
+        index_t n_raw_padded_;
     };

     // Invoker

@@ -813,8 +819,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
                 arg.p_dropout_rescale_,
                 arg.seed_,
                 arg.offset_,
-                arg.raw_lengths_mz_nz_kz_gemm1nz_[0],
-                arg.raw_lengths_mz_nz_kz_gemm1nz_[1]);
+                arg.m_raw_padded_,
+                arg.n_raw_padded_);
         };

         // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt6.hpp

@@ -133,6 +133,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
             make_tuple(Sequence<0, 2, 4, 6, 7, 8>{}, Sequence<1, 3, 5, 9>{}));
     }

+    __host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
+    {
+        constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
+        constexpr auto M5   = mfma.group_size;
+        return index_t(ceil(float(size) / M5) * M5);
+    }
+
     __device__ static auto GetGemm0WaveIdx()
     {
         const index_t thread_id = get_thread_local_1d_id();
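In isolation, the new helper rounds the raw length up to the next multiple of the mfma group size (M5). A runnable sketch of the same arithmetic, with a hypothetical group size of 4 passed in as a parameter rather than taken from MfmaSelector at compile time:

    #include <cmath>
    #include <cstdio>

    using index_t = int;

    // same round-up as the GetPaddedSize added above, with the group size
    // as a runtime argument for demonstration
    index_t get_padded_size(const index_t size, const index_t group_size)
    {
        return index_t(std::ceil(float(size) / group_size) * group_size);
    }

    int main()
    {
        // with group_size == 4: 30 is padded to 32, 32 stays 32
        std::printf("%d\n", get_padded_size(30, 4)); // prints 32
        std::printf("%d\n", get_padded_size(32, 4)); // prints 32
    }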
@@ -1956,12 +1963,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                         MRaw * NRaw * g_idx + m_global * NRaw +
                         n_global; // unique element global 1d id
                     auto global_elem_id =
-                        (global_elem_id_raw % 4) * MRaw + int(global_elem_id_raw / 4) * 4;
+                        (global_elem_id_raw % M4) * NRaw + int(global_elem_id_raw / M4) * M4;

                     blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<
                         decltype(s_slash_p_thread_buf),
                         decltype(z_tenor_buffer),
                         true>(
-                        s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, MRaw);
+                        s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, NRaw);

                     z_thread_copy_vgpr_to_global.Run(
                         z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                         make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
@@ -1983,12 +1990,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                         MRaw * NRaw * g_idx + m_global * NRaw +
                         n_global; // unique element global 1d id
                     auto global_elem_id =
-                        (global_elem_id_raw % 4) * MRaw + int(global_elem_id_raw / 4) * 4;
+                        (global_elem_id_raw % M4) * NRaw + int(global_elem_id_raw / M4) * M4;

                     // P_dropped
                     blockwise_dropout.template ApplyDropoutAttnBwd<
                         decltype(s_slash_p_thread_buf),
                         true>(
-                        s_slash_p_thread_buf, ph, global_elem_id, MRaw);
+                        s_slash_p_thread_buf, ph, global_elem_id, NRaw);
                 }

                 block_sync_lds(); // wait for gemm1 LDS read
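Both hunks fix the same indexing bug: the per-element id that drives the Philox-based dropout was built with a hard-coded group size of 4 and an MRaw stride; it now uses the descriptor's M4 length and strides along NRaw, presumably so that the backward pass regenerates the same per-element dropout offsets as the forward pass. The final raw-length argument of the ApplyDropout* calls changes from MRaw to NRaw in step. A minimal sketch of the corrected mapping; the M4 and NRaw values here are illustrative, not taken from the kernel:

    #include <cstdio>

    using index_t = long;

    // corrected mapping from the hunks above: ids that are consecutive in
    // raw order land NRaw apart, and each group of M4 ids starts M4 further
    // along
    index_t make_global_elem_id(index_t raw, index_t M4, index_t NRaw)
    {
        return (raw % M4) * NRaw + index_t(raw / M4) * M4;
    }

    int main()
    {
        // illustrative values: M4 = 4 (group size), NRaw = 8 (row length)
        for(index_t raw = 0; raw < 8; ++raw)
            std::printf("raw %ld -> %ld\n", raw, make_global_elem_id(raw, 4, 8));
    }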
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt7.hpp

@@ -147,6 +147,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             make_tuple(Sequence<0, 2, 4, 6, 7, 8>{}, Sequence<1, 3, 5, 9>{}));
     }

+    __host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
+    {
+        constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
+        constexpr auto M5   = mfma.group_size;
+        return index_t(ceil(float(size) / M5) * M5);
+    }
+
     __device__ static auto GetGemm0WaveIdx()
     {
         const index_t thread_id = get_thread_local_1d_id();

@@ -1872,12 +1879,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                         MRaw * NRaw * g_idx + m_global * NRaw +
                         n_global; // unique element global 1d id
                     auto global_elem_id =
-                        (global_elem_id_raw % 4) * MRaw + int(global_elem_id_raw / 4) * 4;
+                        (global_elem_id_raw % M4) * NRaw + int(global_elem_id_raw / M4) * M4;

                     blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<
                         decltype(s_slash_p_thread_buf),
                         decltype(z_tenor_buffer),
                         true>(
-                        s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, MRaw);
+                        s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, NRaw);

                     z_thread_copy_vgpr_to_global.Run(
                         z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                         make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),

@@ -1899,11 +1906,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                         MRaw * NRaw * g_idx + m_global * NRaw +
                         n_global; // unique element global 1d id
                     auto global_elem_id =
-                        (global_elem_id_raw % 4) * MRaw + int(global_elem_id_raw / 4) * 4;
+                        (global_elem_id_raw % M4) * NRaw + int(global_elem_id_raw / M4) * M4;

                     // P_dropped
                     blockwise_dropout.template ApplyDropoutAttnBwd<
                         decltype(s_slash_p_thread_buf),
                         true>(
-                        s_slash_p_thread_buf, ph, global_elem_id, MRaw);
+                        s_slash_p_thread_buf, ph, global_elem_id, NRaw);
                 }

                 block_sync_lds(); // wait for gemm1 LDS read
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle_pt2.hpp

@@ -143,6 +143,13 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
             make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
     }

+    __host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
+    {
+        constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
+        constexpr auto N5   = mfma.group_size;
+        return index_t(ceil(float(size) / N5) * N5);
+    }
+
     __device__ static auto GetGemm0WaveIdx()
     {
         const index_t thread_id = get_thread_local_1d_id();
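The forward header gains the same helper as the backward ones; only the constant's name (N5 rather than M5) and the type fed to MfmaSelector (FloatGemm rather than GemmDataType) differ, so forward and backward pad the raw lengths identically whenever the selected mfma group sizes match. For extents in the practical range, the float-ceil form is equivalent to the usual all-integer round-up; a quick self-contained check under an assumed group size of 4:

    #include <cassert>
    #include <cmath>

    using index_t = int;

    int main()
    {
        constexpr index_t group = 4; // assumed group size for this check
        for(index_t size = 1; size <= 4096; ++size)
        {
            const index_t via_ceil = index_t(std::ceil(float(size) / group) * group);
            const index_t via_int  = (size + group - 1) / group * group;
            assert(via_ceil == via_int); // both round up to a multiple of group
        }
        return 0;
    }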