Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5718bc14
"vscode:/vscode.git/clone" did not exist on "229a65f305dfd1701a040a260096503206580268"
Commit
5718bc14
authored
Sep 20, 2022
by
wangshaojie6
Browse files
passing a mask struct
parent
4976eb0c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
41 additions
and
12 deletions
+41
-12
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+33
-3
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
...id/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
+7
-9
library/src/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute/device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
...xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+1
-0
No files found.
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
5718bc14
...
...
@@ -35,6 +35,7 @@ template <typename GridwiseGemm,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
Block2CTileMap
,
typename
ComputeBasePtrOfStridedBatch
,
typename
C0MatrixMask
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
...
...
@@ -57,7 +58,8 @@ __global__ void
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2CTileMap
block_2_ctile_map
,
const
index_t
batch_count
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
)
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
C0MatrixMask
c0_matrix_mask
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
...
@@ -88,7 +90,8 @@ __global__ void
b_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
block_2_ctile_map
);
block_2_ctile_map
,
c0_matrix_mask
);
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
...
...
@@ -106,6 +109,7 @@ __global__ void
ignore
=
block_2_ctile_map
;
ignore
=
batch_count
;
ignore
=
compute_base_ptr_of_batch
;
ignore
=
c0_matrix_mask
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
...
...
@@ -399,6 +403,26 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
MakeCGridDescriptor_G_M_N
({},
{}));
// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out.
struct
C0MatrixMask
{
C0MatrixMask
(
index_t
NRaw
)
:
NRaw_
(
NRaw
)
{}
__host__
__device__
bool
IsUpperTriangle
(
index_t
m
,
index_t
n
)
const
{
return
n
>
m
;
}
__host__
__device__
bool
IsNOutOfBound
(
/*index_t m, */
index_t
n
)
const
{
return
n
>=
NRaw_
;
}
__host__
__device__
bool
IsMaskedElement
(
index_t
m
,
index_t
n
)
const
{
return
IsUpperTriangle
(
m
,
n
)
||
IsNOutOfBound
(
n
);
}
private:
// index_t MRaw_;
index_t
NRaw_
;
};
struct
ComputeBasePtrOfStridedBatch
{
ComputeBasePtrOfStridedBatch
(
index_t
BatchStrideA
,
...
...
@@ -550,6 +574,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
batch_count_
(
Batch
),
compute_base_ptr_of_batch_
{
BatchStrideA
,
BatchStrideB
,
BatchStrideB1
,
c_grid_desc_g_m_n_
},
c0_matrix_mask_
{
NRaw
},
raw_lengths_m_n_k_o_
{
MRaw
,
NRaw
,
KRaw
,
Gemm1NRaw
},
c_extent_lowest_
{
c_gs_ms_gemm1ns_lengths
.
back
()},
c_stride_lowest_
{
c_gs_ms_gemm1ns_strides
.
back
()}
...
...
@@ -587,6 +612,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
index_t
batch_count_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
// check C0 masking and padding
C0MatrixMask
c0_matrix_mask_
;
// For robust IsSupportedArgument() check
std
::
vector
<
index_t
>
raw_lengths_m_n_k_o_
;
index_t
c_extent_lowest_
;
...
...
@@ -634,6 +662,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
ComputeBasePtrOfStridedBatch
,
C0MatrixMask
,
has_main_k_block_loop_
>
;
return
launch_and_time_kernel
(
stream_config
,
...
...
@@ -656,7 +685,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_ctile_map_
,
arg
.
batch_count_
,
arg
.
compute_base_ptr_of_batch_
);
arg
.
compute_base_ptr_of_batch_
,
arg
.
c0_matrix_mask_
);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
View file @
5718bc14
...
...
@@ -366,7 +366,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
};
template
<
bool
HasMainKBlockLoop
,
typename
Block2CTileMap
>
template
<
bool
HasMainKBlockLoop
,
typename
Block2CTileMap
,
typename
C0MatrixMask
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b1_grid
,
...
...
@@ -382,7 +382,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const
B1GridDesc_BK0_N_BK1
&
b1_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2CTileMap
&
block_2_ctile_map
)
const
Block2CTileMap
&
block_2_ctile_map
,
const
C0MatrixMask
&
c0_matrix_mask
)
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
...
...
@@ -405,9 +406,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
return
;
}
// fetch origin N dim(before padding)
const
index_t
n_raw
=
b_grid_desc_bk0_n_bk1
.
GetTransforms
()[
I0
].
GetUpperLengths
()[
I0
];
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
MPerBlock
);
...
...
@@ -764,8 +762,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
{
auto
gemm0_n_block_idx
=
__builtin_amdgcn_readfirstlane
(
gemm1_k_block_outer_index
*
NPerBlock
);
if
((
m_block_data_idx_on_grid
<
gemm0_n_block_idx
)
&&
(
(
m_block_data_idx_on_grid
+
MPerBlock
-
1
)
<
gemm0_n_block_idx
))
if
(
c0_matrix_mask
.
IsUpperTriangle
(
m_block_data_idx_on_grid
,
gemm0_n_block_idx
)
&&
c0_matrix_mask
.
IsUpperTriangle
(
m_block_data_idx_on_grid
+
MPerBlock
-
1
,
gemm0_n_block_idx
))
{
continue
;
}
...
...
@@ -809,7 +807,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const
auto
acc_offset
=
Number
<
acc_idx_n2
+
n4_i
>
{};
if
constexpr
(
MaskOutUpperTriangle
)
{
if
(
n_global
>
m_global
||
n_global
>
n_raw
)
if
(
c0_matrix_mask
.
IsMaskedElement
(
m_global
,
n_global
)
)
{
acc_thread_buf
(
acc_offset
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
...
...
@@ -823,7 +821,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
else
{
// ignore m_global;
if
(
n_global
>
n_raw
)
if
(
c0_matrix_mask
.
IsNOutOfBound
(
n_global
)
)
{
acc_thread_buf
(
acc_offset
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute/device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
View file @
5718bc14
...
...
@@ -37,6 +37,7 @@ static constexpr auto GemmPadded =
using
device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
=
std
::
tuple
<
// clang-format off
// 2 of them are commented out because they trigger the clang-13 issue.
//##############################################| ALayout| B0Layout| B1Layout| CPermuteNumDims_G_M_O| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskOut|
//##############################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Upper|
//##############################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Triangle|
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment