Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c8f6d5d1
Commit
c8f6d5d1
authored
May 04, 2022
by
Chao Liu
Browse files
clean
parent
7b4de775
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
18 additions
and
18 deletions
+18
-18
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
+12
-12
include/ck/utility/amd_xdlops.hpp
include/ck/utility/amd_xdlops.hpp
+4
-4
library/include/ck/library/utility/check_err.hpp
library/include/ck/library/utility/check_err.hpp
+2
-2
No files found.
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
View file @
c8f6d5d1
...
...
@@ -31,7 +31,7 @@ __global__ void
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_grouped_gemm_xdlops_v2r3
(
const
StaticallyIndexedArray
<
GemmDesc
,
MaxGroupCount
>
gemm_desc
_
,
const
StaticallyIndexedArray
<
GemmDesc
,
MaxGroupCount
>
gemm_desc
s
,
const
index_t
group_count
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
...
...
@@ -44,31 +44,31 @@ __global__ void
#if 1
static_for
<
0
,
MaxGroupCount
,
1
>
{}([
&
](
auto
i
)
{
if
(
block_id
>=
gemm_desc
_
[
i
].
BlockStart_
&&
block_id
<
gemm_desc
_
[
i
].
BlockEnd_
&&
if
(
block_id
>=
gemm_desc
s
[
i
].
BlockStart_
&&
block_id
<
gemm_desc
s
[
i
].
BlockEnd_
&&
i
<
group_count
)
{
auto
group_id
=
i
;
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
gemm_desc
_
[
group_id
].
a_ptr
,
gemm_desc
_
[
group_id
].
b_ptr
,
gemm_desc
_
[
group_id
].
c_ptr
,
gemm_desc
s
[
group_id
].
a_ptr
,
gemm_desc
s
[
group_id
].
b_ptr
,
gemm_desc
s
[
group_id
].
c_ptr
,
p_shared
,
gemm_desc
_
[
group_id
].
a_grid_desc_k0_m_k1_
,
gemm_desc
_
[
group_id
].
b_grid_desc_k0_n_k1_
,
gemm_desc
_
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
gemm_desc
s
[
group_id
].
a_grid_desc_k0_m_k1_
,
gemm_desc
s
[
group_id
].
b_grid_desc_k0_n_k1_
,
gemm_desc
s
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
a_element_op
,
b_element_op
,
c_element_op
,
gemm_desc
_
[
group_id
].
grouped_gemm_block_2_ctile_map_
);
gemm_desc
s
[
group_id
].
grouped_gemm_block_2_ctile_map_
);
}
});
#else
const
auto
gemm_desc_ptr
=
reinterpret_cast
<
const
GemmDesc
*>
(
&
gemm_desc
_
);
const
auto
gemm_desc_ptr
=
reinterpret_cast
<
const
GemmDesc
*>
(
&
gemm_desc
s
);
index_t
group_id
=
0
;
static_for
<
0
,
MaxGroupCount
,
1
>
{}([
&
](
auto
i
)
{
group_id
=
(
block_id
>=
gemm_desc
_
[
i
].
BlockStart
&&
block_id
<
gemm_desc
_
[
i
].
BlockEnd
&&
group_id
=
(
block_id
>=
gemm_desc
s
[
i
].
BlockStart
&&
block_id
<
gemm_desc
s
[
i
].
BlockEnd
&&
i
<
group_count
)
?
i
:
group_id
;
...
...
@@ -91,7 +91,7 @@ __global__ void
block_id_grp
);
#endif
#else
ignore
=
gemm_desc
_
;
ignore
=
gemm_desc
s
;
ignore
=
group_count
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
...
...
include/ck/utility/amd_xdlops.hpp
View file @
c8f6d5d1
...
...
@@ -266,8 +266,8 @@ struct intrin_mfma_i32_32x32x8i8<32, 32>
__device__
static
void
Run
(
const
int8x4_t
&
reg_a
,
const
int8x4_t
&
reg_b
,
FloatC
&
reg_c
)
{
reg_c
.
template
AsType
<
int32x16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_mfma_i32_32x32x8i8
(
bit_cast
<
int
>
(
reg_a
),
bit_cast
<
int
>
(
reg_b
),
__builtin_amdgcn_mfma_i32_32x32x8i8
(
bit_cast
<
int
32_t
>
(
reg_a
),
bit_cast
<
int
32_t
>
(
reg_b
),
reg_c
.
template
AsType
<
int32x16_t
>()[
Number
<
0
>
{}],
0
,
0
,
...
...
@@ -285,8 +285,8 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
__device__
static
void
Run
(
const
int8x4_t
&
reg_a
,
const
int8x4_t
&
reg_b
,
FloatC
&
reg_c
)
{
reg_c
.
template
AsType
<
int32x4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_mfma_i32_16x16x16i8
(
bit_cast
<
int
>
(
reg_a
),
bit_cast
<
int
>
(
reg_b
),
__builtin_amdgcn_mfma_i32_16x16x16i8
(
bit_cast
<
int
32_t
>
(
reg_a
),
bit_cast
<
int
32_t
>
(
reg_b
),
reg_c
.
template
AsType
<
int32x4_t
>()[
Number
<
0
>
{}],
0
,
0
,
...
...
library/include/ck/library/utility/check_err.hpp
View file @
c8f6d5d1
...
...
@@ -169,8 +169,8 @@ check_err(const std::vector<T>& out,
for
(
std
::
size_t
i
=
0
;
i
<
ref
.
size
();
++
i
)
{
const
int64_t
out_v
=
static_cast
<
int64_t
>
(
out
[
i
]);
const
int64_t
ref_v
=
static_cast
<
int64_t
>
(
ref
[
i
]);
const
auto
out_v
=
static_cast
<
int64_t
>
(
out
[
i
]);
const
auto
ref_v
=
static_cast
<
int64_t
>
(
ref
[
i
]);
if
(
out_v
!=
ref_v
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment