Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
959ddcf8
"vscode:/vscode.git/clone" did not exist on "3b8664611508c8e8999a45be3f58ea9e7f4ed010"
Commit
959ddcf8
authored
Nov 02, 2023
by
Jing Zhang
Browse files
clang-format
parent
db2e4cf4
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
14 deletions
+12
-14
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
...u/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
+7
-7
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp
...ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp
+4
-6
profiler/src/profile_grouped_gemm.cpp
profiler/src/profile_grouped_gemm.cpp
+1
-1
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
View file @
959ddcf8
...
...
@@ -265,10 +265,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const
index_t
stride_b
=
gemm_descs
[
i
].
stride_B_
;
const
index_t
stride_c
=
gemm_descs
[
i
].
stride_C_
;
const
index_t
m_padded
=
GridwiseGemm
::
CalculateMPadded
(
M
);
const
index_t
n_padded
=
GridwiseGemm
::
CalculateNPadded
(
N
);
const
index_t
k_padded
=
GridwiseGemm
::
CalculateKPadded
(
K
,
K_BATCH
);
const
index_t
k0_padded
=
GridwiseGemm
::
CalculateK0Padded
(
K
,
K_BATCH
);
const
index_t
m_padded
=
GridwiseGemm
::
CalculateMPadded
(
M
);
const
index_t
n_padded
=
GridwiseGemm
::
CalculateNPadded
(
N
);
const
index_t
k_padded
=
GridwiseGemm
::
CalculateKPadded
(
K
,
K_BATCH
);
const
index_t
k0_padded
=
GridwiseGemm
::
CalculateK0Padded
(
K
,
K_BATCH
);
const
auto
c_grid_desc_m_n
=
GridwiseGemm
::
MakeCGridDescriptor_M_N
(
M
,
N
,
stride_c
);
...
...
@@ -320,8 +320,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
auto
&
karg
=
gemm_kernel_args_
[
i
].
karg_
;
const
index_t
k_padded
=
GridwiseGemm
::
CalculateKPadded
(
karg
.
K
,
K_BATCH
);
const
index_t
k0_padded
=
GridwiseGemm
::
CalculateK0Padded
(
karg
.
K
,
K_BATCH
);
const
index_t
k_padded
=
GridwiseGemm
::
CalculateKPadded
(
karg
.
K
,
K_BATCH
);
const
index_t
k0_padded
=
GridwiseGemm
::
CalculateK0Padded
(
karg
.
K
,
K_BATCH
);
const
auto
c_grid_desc_m_n
=
GridwiseGemm
::
MakeCGridDescriptor_M_N
(
karg
.
M
,
karg
.
N
,
karg
.
StrideC
);
...
...
@@ -340,7 +340,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
GroupedGemmBlock2ETileMap
(
local_b2c_tile_map
,
block_start
);
karg
.
KPadded
=
k_padded
;
karg
.
K0Padded
=
k0_padded
;
karg
.
K0Padded
=
k0_padded
;
karg
.
k_batch
=
K_BATCH
;
gemm_kernel_args_
[
i
].
block_2_ctile_map_
=
grouped_block_2_ctile_map
;
gemm_kernel_args_
[
i
].
block_start_
=
block_start
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp
View file @
959ddcf8
...
...
@@ -133,8 +133,6 @@ void add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(
PassThrough
,
PassThrough
>>>&
instances
);
template
<
typename
ALayout
,
typename
BLayout
,
typename
ELayout
,
...
...
@@ -199,13 +197,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
f8_t
>
&&
is_same_v
<
EDataType
,
half_t
>
)
else
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
f8_t
>
&&
is_same_v
<
EDataType
,
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances
(
op_ptrs
);
add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
...
...
profiler/src/profile_grouped_gemm.cpp
View file @
959ddcf8
...
...
@@ -27,7 +27,7 @@ enum struct GemmDataType
F16_F16_F16
,
// 1
BF16_BF16_BF16
,
// 2
INT8_INT8_INT8
,
// 3
F16_F8_F16
,
// 4
F16_F8_F16
,
// 4
};
#define OP_NAME "grouped_gemm"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment