gaoqiong / composable_kernel / Commits / 80da57fd

Commit 80da57fd authored Aug 14, 2023 by letaoqin
Parent: 90f8550b

    fix bug for void type
Showing 2 changed files with 37 additions and 21 deletions.
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp  +35 -18
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp  +2 -3
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
@@ -110,6 +110,13 @@ __global__ void
     const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));

+    const D0DataType* tmp_p_d0_grid = nullptr;
+    if constexpr(!is_same<D0DataType, void>::value)
+    {
+        tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
+    }
+
     // const index_t global_thread_id = get_thread_global_1d_id();
     ck::philox ph(seed, 0, offset);
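Why the temporary pointer: in the call sites below, the old code computed p_d0_grid + d0_batch_offset inside a runtime ternary, and that expression still has to compile when D0DataType is void, where pointer arithmetic on a const void* is not valid standard C++. Hoisting the arithmetic into an if constexpr branch discards it at compile time for the void instantiation. A minimal, self-contained sketch of the idea, using hypothetical names (offset_or_null, DataT) rather than the repository's kernel arguments:

#include <cstddef>
#include <type_traits>

// Hypothetical helper: DataT may be void, exactly like D0DataType in the kernel.
template <typename DataT>
const DataT* offset_or_null(const DataT* base, std::ptrdiff_t batch_offset)
{
    // A runtime ternary such as
    //     base == nullptr ? nullptr : base + batch_offset
    // is compiled for every instantiation, so it breaks when DataT is void.
    // "if constexpr" discards the arithmetic at compile time instead.
    const DataT* tmp = nullptr;
    if constexpr(!std::is_same<DataT, void>::value)
    {
        tmp = base + batch_offset;
    }
    return tmp;
}

int main()
{
    float data[16] = {};
    const float* p = offset_or_null(data, 4);          // real type: pointer is advanced
    const void*  q = offset_or_null<void>(nullptr, 4); // void type: stays nullptr, still compiles
    return (p != nullptr && q == nullptr) ? 0 : 1;
}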
@@ -122,7 +129,7 @@ __global__ void
     GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
         p_a_grid + a_batch_offset,
         p_b_grid + b_batch_offset,
-        p_d0_grid == nullptr ? nullptr : p_d0_grid + d0_batch_offset,
+        tmp_p_d0_grid,
         p_b1_grid + b1_batch_offset,
         p_c_grid + c_batch_offset,
         p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
@@ -155,7 +162,7 @@ __global__ void
     GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
         p_a_grid + a_batch_offset,
         p_b_grid + b_batch_offset,
-        p_d0_grid == nullptr ? nullptr : p_d0_grid + d0_batch_offset,
+        tmp_p_d0_grid,
         p_b1_grid + b1_batch_offset,
         p_c_grid + c_batch_offset,
         p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
@@ -618,11 +625,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
           DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
       b_grid_desc_bk0_n_bk1_{
           DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
-      d0_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths,
-                                                           acc0_biases_gs_ms_ns_strides)},
-      d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{
-          GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
-              d0_grid_desc_m_n_)},
       b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
           b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
       c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
@@ -633,8 +635,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
           Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
       b_grid_desc_g_n_k_{
           Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
-      d0_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths,
-                                                               acc0_biases_gs_ms_ns_strides)},
       b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
           b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
       c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
@@ -685,10 +685,23 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
             c_grid_desc_mblock_mperblock_nblock_nperblock_ =
                 GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                     c_grid_desc_m_n_);
         }

-        d0_n_length_stride_.push_back(acc0_biases_gs_ms_ns_lengths[NumDimG + NumDimM]);
-        d0_n_length_stride_.push_back(acc0_biases_gs_ms_ns_strides[NumDimG + NumDimM]);
+        if constexpr(!is_same<D0DataType, void>::value)
+        {
+            d0_grid_desc_m_n_ = Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths,
+                                                                   acc0_biases_gs_ms_ns_strides);
+            d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
+                GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
+                    d0_grid_desc_m_n_);
+            d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths,
+                                                                       acc0_biases_gs_ms_ns_strides);
+
+            d0_n_length_stride_.push_back(acc0_biases_gs_ms_ns_lengths[NumDimG + NumDimM]);
+            d0_n_length_stride_.push_back(acc0_biases_gs_ms_ns_strides[NumDimG + NumDimM]);
+        }

         is_dropout_          = p_dropout > 0.0; //
         p_dropout_           = 1.f - p_dropout;
         p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
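The hunks at lines 618 and 633 above remove the d0 descriptors from the member-initializer list, and this hunk rebuilds them in the constructor body behind the same if constexpr guard, so the Make* calls run only when a bias type exists. The diff does not show exactly how the unconditional initializers misbehaved for void; the sketch below only illustrates the general pattern, with made-up types (BiasDesc, MakeBiasDesc, Argument) standing in for the CK descriptors:

#include <type_traits>
#include <vector>

// Illustrative stand-ins, not the repository's types.
struct BiasDesc { int m = 0; int n = 0; };

inline BiasDesc MakeBiasDesc(const std::vector<int>& lengths)
{
    // Meaningless (and throwing) when no bias tensor was supplied.
    return BiasDesc{lengths.at(0), lengths.at(1)};
}

template <typename D0DataType>
struct Argument
{
    explicit Argument(const std::vector<int>& bias_lengths)
    // The bias descriptor is deliberately NOT in the member-initializer list:
    // an initializer there would run MakeBiasDesc() for every instantiation,
    // including D0DataType == void, where bias_lengths is empty.
    {
        if constexpr(!std::is_same<D0DataType, void>::value)
        {
            d0_desc_ = MakeBiasDesc(bias_lengths); // built only when a bias type exists
        }
    }

    BiasDesc d0_desc_{}; // stays default-constructed when D0DataType is void
};

int main()
{
    Argument<float> with_bias({128, 64}); // descriptor built from the bias lengths
    Argument<void>  no_bias({});          // guarded branch never runs; nothing to build
    return with_bias.d0_desc_.m == 128 ? 0 : 1;
}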
@@ -1010,6 +1023,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
             return false;
         }

+        if constexpr(!is_same<D0DataType, void>::value)
+        {
             if(arg.d0_n_length_stride_[1] == 1 &&
                arg.d0_n_length_stride_[0] % Acc0BiasTransferSrcScalarPerVector != 0)
             {
@@ -1019,6 +1034,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
             {
                 return false;
             }
+        }
+
         // Note: we need raw lengths since threadwise copy can not handle vector load when part of
         // vector is out of bounds
         // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
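The two IsSupportedArgument hunks wrap the pre-existing D0 vector-load check in the same if constexpr guard, so d0_n_length_stride_ is only inspected when a bias type is present. Below is a standalone approximation of that constraint; the function name (IsBiasArgumentSupported) and the argument layout are placeholders, not the library's API, and the check mirrors the condition visible in the diff: a contiguous N dimension (stride 1) must have a raw length divisible by the transfer vector width.

#include <array>
#include <cstdio>
#include <type_traits>

template <typename D0DataType, int ScalarPerVector>
bool IsBiasArgumentSupported(const std::array<long, 2>& d0_n_length_stride)
{
    if constexpr(!std::is_same<D0DataType, void>::value)
    {
        const long n_length = d0_n_length_stride[0];
        const long n_stride = d0_n_length_stride[1];
        // Vector loads along N are only legal when N is contiguous (stride 1)
        // and its raw length is a whole number of vectors.
        if(n_stride == 1 && n_length % ScalarPerVector != 0)
        {
            return false;
        }
    }
    return true; // no bias at all, or the constraint is satisfied
}

int main()
{
    std::printf("%d\n", IsBiasArgumentSupported<float, 4>({130, 1})); // 0: 130 is not a multiple of 4
    std::printf("%d\n", IsBiasArgumentSupported<float, 4>({128, 1})); // 1: vector loads are fine
    std::printf("%d\n", IsBiasArgumentSupported<void, 4>({0, 0}));    // 1: check skipped entirely
    return 0;
}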
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
@@ -992,9 +992,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                                                        wave_m_n_id[I0], // NInputIndex
                                                        0));             // register number

-        const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_d0_grid, d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
-
         constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = // for blockwise copy
             make_naive_tensor_descriptor_packed(make_tuple(m0,             // MRepeat
                                                            DropoutNRepeat, // NRepeat
@@ -1293,6 +1290,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
             // add bias
             if constexpr(!std::is_void<D0DataType>::value)
             {
+                const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+                    p_d0_grid, d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
                 // get register
                 StaticBuffer<AddressSpaceEnum::Vgpr,
                              D0DataType,
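The gridwise change mirrors the device-side one: the global D0 buffer view is no longer created unconditionally near line 992 but only inside the existing if constexpr(!std::is_void<D0DataType>::value) branch that adds the bias, so nothing D0-typed is materialized when the bias type is void. A rough sketch of why the typed view belongs inside that branch, using an illustrative DynamicBuffer stand-in rather than CK's make_dynamic_buffer:

#include <cstddef>
#include <type_traits>

// Minimal stand-in for a typed global-memory view. Instantiating it with a
// real element type is fine; instantiating it with void would not be, which is
// why its construction sits inside the constexpr branch below.
template <typename T>
struct DynamicBuffer
{
    const T*    p_data;
    std::size_t element_space;
    const T& operator[](std::size_t i) const { return p_data[i]; }
};

template <typename D0DataType>
float accumulate_with_optional_bias(float acc, const void* p_d0_grid, std::size_t space)
{
    if constexpr(!std::is_void<D0DataType>::value)
    {
        // Only built (and indexed) when D0DataType is a real element type.
        DynamicBuffer<D0DataType> d0_grid_buf{static_cast<const D0DataType*>(p_d0_grid), space};
        acc += static_cast<float>(d0_grid_buf[0]);
    }
    return acc;
}

int main()
{
    float bias[4] = {0.5f, 0.f, 0.f, 0.f};
    float with    = accumulate_with_optional_bias<float>(1.0f, bias, 4);   // 1.5f
    float without = accumulate_with_optional_bias<void>(1.0f, nullptr, 0); // 1.0f, branch discarded
    return (with > without) ? 0 : 1;
}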