Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
cf875a91
Commit
cf875a91
authored
Aug 23, 2022
by
Anthony Chang
Browse files
properly check size K in IsSupportedArgument()
parent
55b19fce
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
27 additions
and
18 deletions
+27
-18
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp
...tion/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp
+11
-4
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
...n/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
+9
-1
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp
+7
-13
No files found.
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp
View file @
cf875a91
...
@@ -448,13 +448,15 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
...
@@ -448,13 +448,15 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
b1_element_op_
{
b1_element_op
},
b1_element_op_
{
b1_element_op
},
c_element_op_
{
c_element_op
},
c_element_op_
{
c_element_op
},
batch_count_
(
Batch
),
batch_count_
(
Batch
),
compute_base_ptr_of_batch_
{
BatchStrideA
,
BatchStrideB
,
BatchStrideB1
,
BatchStrideC
}
compute_base_ptr_of_batch_
{
BatchStrideA
,
BatchStrideB
,
BatchStrideB1
,
BatchStrideC
},
lengths_m_n_k_o_
{
MRaw
,
NRaw
,
KRaw
,
Gemm1NRaw
}
{
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
b_grid_desc_bk0_n_bk1_
,
b1_grid_desc_bk0_n_bk1_
,
b1_grid_desc_bk0_n_bk1_
,
c_grid_desc_m_n_
,
c_grid_desc_m_n_
,
block_2_ctile_map_
))
block_2_ctile_map_
,
lengths_m_n_k_o_
))
{
{
c_grid_desc_mblock_mperblock_nblock_nperblock_
=
c_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
...
@@ -481,6 +483,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
...
@@ -481,6 +483,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
CElementwiseOperation
c_element_op_
;
CElementwiseOperation
c_element_op_
;
index_t
batch_count_
;
index_t
batch_count_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
// For robust IsSupportedArgument() check
std
::
vector
<
index_t
>
lengths_m_n_k_o_
;
};
};
// Invoker
// Invoker
...
@@ -494,7 +499,8 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
...
@@ -494,7 +499,8 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
c_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
))
arg
.
block_2_ctile_map_
,
arg
.
lengths_m_n_k_o_
))
{
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
}
...
@@ -588,7 +594,8 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
...
@@ -588,7 +594,8 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
c_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
);
arg
.
block_2_ctile_map_
,
arg
.
lengths_m_n_k_o_
);
}
}
// polymorphic
// polymorphic
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
View file @
cf875a91
...
@@ -200,7 +200,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
...
@@ -200,7 +200,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
B1GridDesc_BK0_N_BK1
&
b1_grid_desc_bk0_n_bk1
,
const
B1GridDesc_BK0_N_BK1
&
b1_grid_desc_bk0_n_bk1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
const
Block2CTileMap
&
block_2_ctile_map
)
const
Block2CTileMap
&
block_2_ctile_map
,
const
std
::
vector
<
index_t
>&
lengths_m_n_k_o
)
{
{
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
...
@@ -216,6 +217,13 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
...
@@ -216,6 +217,13 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
return
false
;
return
false
;
}
}
// K is rounded to nearest multiples of K1 during tensor transformation so instead get KRaw
const
auto
KRaw
=
lengths_m_n_k_o
[
2
];
if
(
!
(
KRaw
%
AK1
==
0
&&
KRaw
%
BK1
==
0
))
{
return
false
;
}
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
Gemm1N
%
Gemm1NPerBlock
==
0
))
Gemm1N
%
Gemm1NPerBlock
==
0
))
{
{
...
...
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp
View file @
cf875a91
...
@@ -68,7 +68,8 @@ TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddN)
...
@@ -68,7 +68,8 @@ TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddN)
this
->
Run
();
this
->
Run
();
}
}
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
DISABLED_Test_FP16_OddK
)
// Currently expected that no kernels can support this case
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
Test_FP16_OddK
)
{
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
128
,
128
,
33
,
128
,
1
},
{
128
,
128
,
33
,
128
,
1
},
...
@@ -108,7 +109,7 @@ using ck::tensor_operation::device::GemmSpecialization;
...
@@ -108,7 +109,7 @@ using ck::tensor_operation::device::GemmSpecialization;
TEST
(
TestBatchedGemmGemmInterface
,
GemmSpecializationSizeMatch
)
TEST
(
TestBatchedGemmGemmInterface
,
GemmSpecializationSizeMatch
)
{
{
int
P
=
12
9
;
// requires padding
int
P
=
12
0
;
// requires padding
int
Q
=
128
;
// do not require padding
int
Q
=
128
;
// do not require padding
// IsSupported(M, N, K, O)
// IsSupported(M, N, K, O)
...
@@ -134,18 +135,11 @@ TEST(TestBatchedGemmGemmInterface, GemmSpecializationSizeMatch)
...
@@ -134,18 +135,11 @@ TEST(TestBatchedGemmGemmInterface, GemmSpecializationSizeMatch)
TEST
(
TestBatchedGemmGemmInterface
,
GemmSpecializationSizeMismatch
)
TEST
(
TestBatchedGemmGemmInterface
,
GemmSpecializationSizeMismatch
)
{
{
int
P
=
129
;
// requires padding
int
Q
=
128
;
// do not require padding
// IsSupported(M, N, K, O)
// IsSupported(M, N, K, O)
// clang-format off
// clang-format off
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
Q
,
Q
,
Q
,
P
));
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
Default
>
{}.
IsSupported
(
128
,
128
,
120
,
128
));
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
Q
,
Q
,
P
,
P
));
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
128
,
128
,
128
,
120
));
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
Q
,
P
,
Q
,
P
));
// Kernel can't support odd K because K must be multiples of K1 values of either A or B
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
P
,
Q
,
Q
,
P
));
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKOPadding
>
{}.
IsSupported
(
128
,
128
,
129
,
128
));
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
Q
,
P
,
P
,
P
));
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
P
,
P
,
Q
,
P
));
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
P
,
Q
,
P
,
P
));
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
P
,
P
,
P
,
P
));
// clang-format on
// clang-format on
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment