Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c75a3c17
"awq/modules/git@developer.sourcefind.cn:OpenDAS/autoawq.git" did not exist on "66adba4a3397dadd4c39d62c504902ed8d0a38dc"
Commit
c75a3c17
authored
Feb 09, 2023
by
guangzlu
Browse files
added z tensor for dropout storing
parent
6b041227
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
0 deletions
+24
-0
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp
...n/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp
+5
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp
..._grouped_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp
+19
-0
No files found.
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp
View file @
c75a3c17
...
@@ -79,6 +79,7 @@ template <index_t NumDimG,
...
@@ -79,6 +79,7 @@ template <index_t NumDimG,
typename
B0DataType
,
typename
B0DataType
,
typename
B1DataType
,
typename
B1DataType
,
typename
CDataType
,
typename
CDataType
,
typename
ZDataType
,
typename
LSEDataType
,
typename
LSEDataType
,
typename
Acc0BiasDataType
,
typename
Acc0BiasDataType
,
typename
Acc1BiasDataType
,
typename
Acc1BiasDataType
,
...
@@ -104,6 +105,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermuteTrain : public BaseOperator
...
@@ -104,6 +105,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermuteTrain : public BaseOperator
std
::
vector
<
index_t
>
c_gs_ms_os_lengths
;
std
::
vector
<
index_t
>
c_gs_ms_os_lengths
;
std
::
vector
<
index_t
>
c_gs_ms_os_strides
;
std
::
vector
<
index_t
>
c_gs_ms_os_strides
;
std
::
vector
<
index_t
>
z_gs_ms_os_lengths
;
std
::
vector
<
index_t
>
z_gs_ms_os_strides
;
std
::
vector
<
index_t
>
lse_gs_ms_lengths
;
std
::
vector
<
index_t
>
lse_gs_ms_lengths
;
std
::
vector
<
index_t
>
lse_gs_ms_strides
;
std
::
vector
<
index_t
>
lse_gs_ms_strides
;
...
@@ -119,6 +123,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermuteTrain : public BaseOperator
...
@@ -119,6 +123,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermuteTrain : public BaseOperator
std
::
vector
<
const
void
*>
p_b0_vec
,
std
::
vector
<
const
void
*>
p_b0_vec
,
std
::
vector
<
const
void
*>
p_b1_vec
,
std
::
vector
<
const
void
*>
p_b1_vec
,
std
::
vector
<
void
*>
p_c_vec
,
std
::
vector
<
void
*>
p_c_vec
,
std
::
vector
<
void
*>
p_z_vec
,
std
::
vector
<
void
*>
p_lse_vec
,
std
::
vector
<
void
*>
p_lse_vec
,
std
::
vector
<
std
::
vector
<
const
void
*>>
p_acc0_biases_vec
,
std
::
vector
<
std
::
vector
<
const
void
*>>
p_acc0_biases_vec
,
std
::
vector
<
std
::
vector
<
const
void
*>>
p_acc1_biases_vec
,
std
::
vector
<
std
::
vector
<
const
void
*>>
p_acc1_biases_vec
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp
View file @
c75a3c17
...
@@ -92,6 +92,8 @@ __global__ void
...
@@ -92,6 +92,8 @@ __global__ void
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetB1BasePtr
(
g_idx
)));
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetB1BasePtr
(
g_idx
)));
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetCBasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetCBasePtr
(
g_idx
)));
const
long_index_t
z_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetZBasePtr
(
g_idx
)));
const
long_index_t
lse_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
const
long_index_t
lse_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetLSEBasePtr
(
g_idx
)));
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetLSEBasePtr
(
g_idx
)));
...
@@ -100,6 +102,7 @@ __global__ void
...
@@ -100,6 +102,7 @@ __global__ void
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
arg_ptr
[
group_id
].
p_b1_grid_
+
b1_batch_offset
,
arg_ptr
[
group_id
].
p_b1_grid_
+
b1_batch_offset
,
arg_ptr
[
group_id
].
p_c_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_c_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_z_grid_
+
z_batch_offset
,
arg_ptr
[
group_id
].
p_lse_grid_
+
lse_batch_offset
,
arg_ptr
[
group_id
].
p_lse_grid_
+
lse_batch_offset
,
p_shared
,
p_shared
,
a_element_op
,
a_element_op
,
...
@@ -111,6 +114,7 @@ __global__ void
...
@@ -111,6 +114,7 @@ __global__ void
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg_ptr
[
group_id
].
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg_ptr
[
group_id
].
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
////////
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
arg_ptr
[
group_id
].
c0_matrix_mask_
,
arg_ptr
[
group_id
].
c0_matrix_mask_
,
...
@@ -361,11 +365,13 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
...
@@ -361,11 +365,13 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
const
BGridDesc_G_N_K
&
b_grid_desc_g_n_k
,
const
BGridDesc_G_N_K
&
b_grid_desc_g_n_k
,
const
B1GridDesc_G_N_K
&
b1_grid_desc_g_n_k
,
const
B1GridDesc_G_N_K
&
b1_grid_desc_g_n_k
,
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
,
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
,
const
ZGridDesc_G_M_N
&
z_grid_desc_g_m_n
,
index_t
BatchStrideLSE
)
index_t
BatchStrideLSE
)
:
a_grid_desc_g_m_k_
(
a_grid_desc_g_m_k
),
:
a_grid_desc_g_m_k_
(
a_grid_desc_g_m_k
),
b_grid_desc_g_n_k_
(
b_grid_desc_g_n_k
),
b_grid_desc_g_n_k_
(
b_grid_desc_g_n_k
),
b1_grid_desc_g_n_k_
(
b1_grid_desc_g_n_k
),
b1_grid_desc_g_n_k_
(
b1_grid_desc_g_n_k
),
c_grid_desc_g_m_n_
(
c_grid_desc_g_m_n
),
c_grid_desc_g_m_n_
(
c_grid_desc_g_m_n
),
z_grid_desc_g_m_n_
(
z_grid_desc_g_m_n
),
BatchStrideLSE_
(
BatchStrideLSE
)
BatchStrideLSE_
(
BatchStrideLSE
)
{
{
}
}
...
@@ -390,6 +396,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
...
@@ -390,6 +396,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
return
c_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
return
c_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
}
__host__
__device__
constexpr
long_index_t
GetZBasePtr
(
index_t
g_idx
)
const
{
return
z_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetLSEBasePtr
(
index_t
g_idx
)
const
__host__
__device__
constexpr
long_index_t
GetLSEBasePtr
(
index_t
g_idx
)
const
{
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideLSE_
);
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideLSE_
);
...
@@ -400,6 +411,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
...
@@ -400,6 +411,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
ZGridDesc_G_M_N
&
z_grid_desc_g_m_n_
;
index_t
BatchStrideLSE_
;
index_t
BatchStrideLSE_
;
};
};
...
@@ -612,6 +624,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
...
@@ -612,6 +624,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
b_grid_desc_g_n_k
,
b_grid_desc_g_n_k
,
b1_grid_desc_g_n_k
,
b1_grid_desc_g_n_k
,
c_grid_desc_g_m_n
,
c_grid_desc_g_m_n
,
z_grid_desc_g_m_n
,
type_convert
<
index_t
>
(
lse_grid_desc_m
.
GetElementSpaceSize
()));
type_convert
<
index_t
>
(
lse_grid_desc_m
.
GetElementSpaceSize
()));
// C0 mask
// C0 mask
...
@@ -635,11 +648,13 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
...
@@ -635,11 +648,13 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
p_b_grid
,
p_b_grid
,
p_b1_grid
,
p_b1_grid
,
p_c_grid
,
p_c_grid
,
p_z_grid
,
p_lse_grid
,
p_lse_grid
,
a_grid_desc_ak0_m_ak1
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
z_grid_desc_g_m_n
,
lse_grid_desc_m
,
lse_grid_desc_m
,
block_2_ctile_map
.
CalculateGridSize
(
c_grid_desc_m_n
),
block_2_ctile_map
.
CalculateGridSize
(
c_grid_desc_m_n
),
compute_base_ptr_of_batch
,
compute_base_ptr_of_batch
,
...
@@ -912,6 +927,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
...
@@ -912,6 +927,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
std
::
vector
<
const
void
*>
p_b_vec
,
std
::
vector
<
const
void
*>
p_b_vec
,
std
::
vector
<
const
void
*>
p_b1_vec
,
std
::
vector
<
const
void
*>
p_b1_vec
,
std
::
vector
<
void
*>
p_c_vec
,
std
::
vector
<
void
*>
p_c_vec
,
std
::
vector
<
void
*>
p_z_vec
,
std
::
vector
<
void
*>
p_lse_vec
,
std
::
vector
<
void
*>
p_lse_vec
,
std
::
vector
<
std
::
vector
<
const
void
*>>
p_acc0_biases_vec
,
std
::
vector
<
std
::
vector
<
const
void
*>>
p_acc0_biases_vec
,
std
::
vector
<
std
::
vector
<
const
void
*>>
p_acc1_biases_vec
,
std
::
vector
<
std
::
vector
<
const
void
*>>
p_acc1_biases_vec
,
...
@@ -928,6 +944,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
...
@@ -928,6 +944,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
p_b_vec
,
p_b_vec
,
p_b1_vec
,
p_b1_vec
,
p_c_vec
,
p_c_vec
,
p_z_vec
,
p_lse_vec
,
p_lse_vec
,
p_acc0_biases_vec
,
p_acc0_biases_vec
,
p_acc1_biases_vec
,
p_acc1_biases_vec
,
...
@@ -949,6 +966,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
...
@@ -949,6 +966,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
std
::
vector
<
const
void
*>
p_b_vec
,
std
::
vector
<
const
void
*>
p_b_vec
,
std
::
vector
<
const
void
*>
p_b1_vec
,
std
::
vector
<
const
void
*>
p_b1_vec
,
std
::
vector
<
void
*>
p_c_vec
,
std
::
vector
<
void
*>
p_c_vec
,
std
::
vector
<
void
*>
p_z_vec
,
std
::
vector
<
void
*>
p_lse_vec
,
std
::
vector
<
void
*>
p_lse_vec
,
std
::
vector
<
std
::
vector
<
const
void
*>>
p_acc0_biases_vec
,
std
::
vector
<
std
::
vector
<
const
void
*>>
p_acc0_biases_vec
,
std
::
vector
<
std
::
vector
<
const
void
*>>
p_acc1_biases_vec
,
std
::
vector
<
std
::
vector
<
const
void
*>>
p_acc1_biases_vec
,
...
@@ -965,6 +983,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
...
@@ -965,6 +983,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
p_b_vec
,
p_b_vec
,
p_b1_vec
,
p_b1_vec
,
p_c_vec
,
p_c_vec
,
p_z_vec
,
p_lse_vec
,
p_lse_vec
,
p_acc0_biases_vec
,
p_acc0_biases_vec
,
p_acc1_biases_vec
,
p_acc1_biases_vec
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment