Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5e1b0dd8
Commit
5e1b0dd8
authored
Aug 31, 2022
by
Anthony Chang
Browse files
IsSupportedArgument checks
parent
c3920de4
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
68 additions
and
3 deletions
+68
-3
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+38
-2
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
.../device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
+30
-1
No files found.
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
5e1b0dd8
...
...
@@ -544,7 +544,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
c_element_op_
{
c_element_op
},
batch_count_
(
Batch
),
compute_base_ptr_of_batch_
{
BatchStrideA
,
BatchStrideB
,
BatchStrideB1
,
c_grid_desc_g_m_n_
}
BatchStrideA
,
BatchStrideB
,
BatchStrideB1
,
c_grid_desc_g_m_n_
},
raw_lengths_m_n_k_o_
{
MRaw
,
NRaw
,
KRaw
,
Gemm1NRaw
},
c_extent_lowest_
{
c_gs_ms_gemm1ns_lengths
.
back
()},
c_stride_lowest_
{
c_gs_ms_gemm1ns_strides
.
back
()}
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
...
...
@@ -578,6 +581,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
CElementwiseOperation
c_element_op_
;
index_t
batch_count_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
// For robust IsSupportedArgument() check
std
::
vector
<
index_t
>
raw_lengths_m_n_k_o_
;
index_t
c_extent_lowest_
;
index_t
c_stride_lowest_
;
};
// Invoker
...
...
@@ -692,7 +700,35 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
return
false
;
}
// TODO: Check A/B0/B1 length & stride and scalar per vector
// Note: we need raw lengths since threadwise copy cannot handle a vector load when part of
// the vector is out of bounds
const
auto
MRaw
=
arg
.
raw_lengths_m_n_k_o_
[
0
];
const
auto
NRaw
=
arg
.
raw_lengths_m_n_k_o_
[
1
];
const
auto
KRaw
=
arg
.
raw_lengths_m_n_k_o_
[
2
];
const
auto
Gemm1NRaw
=
arg
.
raw_lengths_m_n_k_o_
[
3
];
// Check scalar per vector requirement
const
auto
a_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
?
KRaw
:
MRaw
;
const
auto
b_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>
?
NRaw
:
KRaw
;
const
auto
b1_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
B1Layout
>
?
Gemm1NRaw
:
NRaw
;
const
auto
c_extent_lowest
=
arg
.
c_extent_lowest_
;
if
(
!
(
a_extent_lowest
%
ABlockTransferSrcScalarPerVector
==
0
&&
b_extent_lowest
%
BBlockTransferSrcScalarPerVector
==
0
&&
b1_extent_lowest
%
B1BlockTransferSrcScalarPerVector
==
0
&&
c_extent_lowest
%
CShuffleBlockTransferScalarPerVector_NPerBlock
==
0
))
{
return
false
;
}
// Check vector store requirement; assumes last dimension in N to be contiguous
if
(
arg
.
c_stride_lowest_
!=
1
)
{
return
false
;
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
View file @
5e1b0dd8
...
...
@@ -459,7 +459,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
b1_element_op_
{
b1_element_op
},
c_element_op_
{
c_element_op
},
batch_count_
(
Batch
),
compute_base_ptr_of_batch_
{
BatchStrideA
,
BatchStrideB
,
BatchStrideB1
,
BatchStrideC
}
compute_base_ptr_of_batch_
{
BatchStrideA
,
BatchStrideB
,
BatchStrideB1
,
BatchStrideC
},
raw_lengths_m_n_k_o_
{
MRaw
,
NRaw
,
KRaw
,
Gemm1NRaw
}
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
...
...
@@ -492,6 +493,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
CElementwiseOperation
c_element_op_
;
index_t
batch_count_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
// For robust IsSupportedArgument() check
std
::
vector
<
index_t
>
raw_lengths_m_n_k_o_
;
};
// Invoker
...
...
@@ -595,6 +599,31 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return
false
;
}
// Note: we need raw lengths since threadwise copy cannot handle a vector load when part of
// the vector is out of bounds
const
auto
MRaw
=
arg
.
raw_lengths_m_n_k_o_
[
0
];
const
auto
NRaw
=
arg
.
raw_lengths_m_n_k_o_
[
1
];
const
auto
KRaw
=
arg
.
raw_lengths_m_n_k_o_
[
2
];
const
auto
Gemm1NRaw
=
arg
.
raw_lengths_m_n_k_o_
[
3
];
// Check scalar per vector requirement
const
auto
a_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
?
KRaw
:
MRaw
;
const
auto
b_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>
?
NRaw
:
KRaw
;
const
auto
b1_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
B1Layout
>
?
Gemm1NRaw
:
NRaw
;
const
auto
c_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>
?
Gemm1NRaw
:
MRaw
;
if
(
!
(
a_extent_lowest
%
ABlockTransferSrcScalarPerVector
==
0
&&
b_extent_lowest
%
BBlockTransferSrcScalarPerVector
==
0
&&
b1_extent_lowest
%
B1BlockTransferSrcScalarPerVector
==
0
&&
c_extent_lowest
%
CShuffleBlockTransferScalarPerVector_NPerBlock
==
0
))
{
return
false
;
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment