Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
bf86f84b
"...composable_kernel.git" did not exist on "c72a0d3ed909689112a2caa66f26085229452e11"
Commit
bf86f84b
authored
Aug 23, 2022
by
Anthony Chang
Browse files
properly check size requirement given SrcScalarPerVector in IsSupportedArgument()
parent
cf875a91
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
6 deletions
+34
-6
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp
...tion/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp
+30
-5
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp
+4
-1
No files found.
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp
View file @
bf86f84b
...
@@ -449,14 +449,14 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
...
@@ -449,14 +449,14 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
c_element_op_
{
c_element_op
},
c_element_op_
{
c_element_op
},
batch_count_
(
Batch
),
batch_count_
(
Batch
),
compute_base_ptr_of_batch_
{
BatchStrideA
,
BatchStrideB
,
BatchStrideB1
,
BatchStrideC
},
compute_base_ptr_of_batch_
{
BatchStrideA
,
BatchStrideB
,
BatchStrideB1
,
BatchStrideC
},
lengths_m_n_k_o_
{
MRaw
,
NRaw
,
KRaw
,
Gemm1NRaw
}
raw_
lengths_m_n_k_o_
{
MRaw
,
NRaw
,
KRaw
,
Gemm1NRaw
}
{
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
b_grid_desc_bk0_n_bk1_
,
b1_grid_desc_bk0_n_bk1_
,
b1_grid_desc_bk0_n_bk1_
,
c_grid_desc_m_n_
,
c_grid_desc_m_n_
,
block_2_ctile_map_
,
block_2_ctile_map_
,
lengths_m_n_k_o_
))
raw_
lengths_m_n_k_o_
))
{
{
c_grid_desc_mblock_mperblock_nblock_nperblock_
=
c_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
...
@@ -485,7 +485,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
...
@@ -485,7 +485,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
// For robust IsSupportedArgument() check
// For robust IsSupportedArgument() check
std
::
vector
<
index_t
>
lengths_m_n_k_o_
;
std
::
vector
<
index_t
>
raw_
lengths_m_n_k_o_
;
};
};
// Invoker
// Invoker
...
@@ -500,7 +500,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
...
@@ -500,7 +500,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
c_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
,
arg
.
block_2_ctile_map_
,
arg
.
lengths_m_n_k_o_
))
arg
.
raw_
lengths_m_n_k_o_
))
{
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
}
...
@@ -590,12 +590,37 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
...
@@ -590,12 +590,37 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
return
false
;
return
false
;
}
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds
const
auto
MRaw
=
arg
.
raw_lengths_m_n_k_o_
[
0
];
const
auto
NRaw
=
arg
.
raw_lengths_m_n_k_o_
[
1
];
const
auto
KRaw
=
arg
.
raw_lengths_m_n_k_o_
[
2
];
const
auto
Gemm1NRaw
=
arg
.
raw_lengths_m_n_k_o_
[
3
];
// Check scalar per vector requirement
const
auto
a_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
?
KRaw
:
MRaw
;
const
auto
b_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>
?
NRaw
:
KRaw
;
const
auto
b1_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
B1Layout
>
?
Gemm1NRaw
:
NRaw
;
const
auto
c_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>
?
Gemm1NRaw
:
MRaw
;
if
(
!
(
a_extent_lowest
%
ABlockTransferSrcScalarPerVector
==
0
&&
b_extent_lowest
%
BBlockTransferSrcScalarPerVector
==
0
&&
b1_extent_lowest
%
B1BlockTransferSrcScalarPerVector
==
0
&&
c_extent_lowest
%
CShuffleBlockTransferScalarPerVector_NPerBlock
==
0
))
{
return
false
;
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
c_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
,
arg
.
block_2_ctile_map_
,
arg
.
lengths_m_n_k_o_
);
arg
.
raw_
lengths_m_n_k_o_
);
}
}
// polymorphic
// polymorphic
...
...
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp
View file @
bf86f84b
...
@@ -78,7 +78,8 @@ TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddK)
...
@@ -78,7 +78,8 @@ TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddK)
this
->
Run
();
this
->
Run
();
}
}
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
DISABLED_Test_FP16_OddO
)
// If kernel B1Layout is RowMajor, expect not to support odd O size
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
Test_FP16_OddO
)
{
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
128
,
128
,
32
,
129
,
1
},
{
128
,
128
,
32
,
129
,
1
},
...
@@ -141,5 +142,7 @@ TEST(TestBatchedGemmGemmInterface, GemmSpecializationSizeMismatch)
...
@@ -141,5 +142,7 @@ TEST(TestBatchedGemmGemmInterface, GemmSpecializationSizeMismatch)
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
128
,
128
,
128
,
120
));
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
128
,
128
,
128
,
120
));
// Kernel can't support odd K because K must be multiples of K1 values of either A or B
// Kernel can't support odd K because K must be multiples of K1 values of either A or B
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKOPadding
>
{}.
IsSupported
(
128
,
128
,
129
,
128
));
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKOPadding
>
{}.
IsSupported
(
128
,
128
,
129
,
128
));
// Kernel can't support odd O size because B1SrcScalarPerVector=8 and must satisfy SizeO % 8 == 0
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKOPadding
>
{}.
IsSupported
(
128
,
128
,
128
,
129
));
// clang-format on
// clang-format on
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment