Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e5d6cf9c
Commit
e5d6cf9c
authored
Jan 20, 2025
by
rtmadduri
Browse files
clang formatting
parent
220f40c9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
18 deletions
+14
-18
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
...u/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
+14
-18
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
View file @
e5d6cf9c
...
@@ -200,9 +200,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
...
@@ -200,9 +200,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
ComputeTypeA
,
ComputeTypeA
,
ComputeTypeB
>
;
ComputeTypeB
>
;
using
Block2ETileMap
=
using
Block2ETileMap
=
BlockToCTileMap_Grouped_M00_N0_M01Adapt
<
8
,
MPerBlock
,
NPerBlock
>
;
BlockToCTileMap_Grouped_M00_N0_M01Adapt
<
8
,
MPerBlock
,
NPerBlock
>
;
using
GroupedGemmBlock2ETileMap
=
OffsettedBlockToCTileMap
<
Block2ETileMap
>
;
using
GroupedGemmBlock2ETileMap
=
OffsettedBlockToCTileMap
<
Block2ETileMap
>
;
using
KernelArgument
=
typename
GridwiseGemm
::
Argument
;
using
KernelArgument
=
typename
GridwiseGemm
::
Argument
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
...
@@ -215,11 +214,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
...
@@ -215,11 +214,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
GemmTransKernelArg
()
=
default
;
GemmTransKernelArg
()
=
default
;
GemmTransKernelArg
(
KernelArgument
&&
karg
,
GemmTransKernelArg
(
KernelArgument
&&
karg
,
// GroupedGemmBlock2ETileMap&& b2c_map,
// GroupedGemmBlock2ETileMap&& b2c_map,
index_t
block_start
,
index_t
block_start
,
index_t
block_end
)
index_t
block_end
)
:
karg_
{
karg
},
:
karg_
{
karg
},
// block_2_ctile_map_{b2c_map},
// block_2_ctile_map_{b2c_map},
block_start_
{
block_start
},
block_start_
{
block_start
},
block_end_
{
block_end
}
block_end_
{
block_end
}
{
{
...
@@ -278,11 +277,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
...
@@ -278,11 +277,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const
index_t
stride_b
=
gemm_descs
[
i
].
stride_B_
;
const
index_t
stride_b
=
gemm_descs
[
i
].
stride_B_
;
const
index_t
stride_c
=
gemm_descs
[
i
].
stride_C_
;
const
index_t
stride_c
=
gemm_descs
[
i
].
stride_C_
;
index_t
gdx
,
gdy
,
gdz
;
index_t
gdx
,
gdy
,
gdz
;
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
M
,
N
,
K_BATCH
);
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
M
,
N
,
K_BATCH
);
const
auto
local_b2c_tile_map
=
Block2ETileMap
{
gdx
,
gdy
,
gdz
};
const
auto
local_b2c_tile_map
=
Block2ETileMap
{
gdx
,
gdy
,
gdz
};
const
index_t
grid_size_grp
=
local_b2c_tile_map
.
CalculateGridSize
(
M
,
N
);
const
index_t
grid_size_grp
=
local_b2c_tile_map
.
CalculateGridSize
(
M
,
N
);
// const index_t grid_size_grp = gdx * gdy * gdz;
// const index_t grid_size_grp = gdx * gdy * gdz;
const
index_t
block_start
=
grid_size_
;
const
index_t
block_start
=
grid_size_
;
...
@@ -306,12 +305,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
...
@@ -306,12 +305,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
K_BATCH
};
K_BATCH
};
// gemm_kernel_args_.emplace_back(
// gemm_kernel_args_.emplace_back(
// std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
// std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
gemm_kernel_args_
.
emplace_back
(
std
::
move
(
karg
),
block_start
,
block_end
);
gemm_kernel_args_
.
emplace_back
(
std
::
move
(
karg
),
block_start
,
block_end
);
}
}
}
}
...
@@ -333,8 +329,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
...
@@ -333,8 +329,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
// const index_t m_padded = GridwiseGemm::CalculateMPadded(karg.M);
// const index_t m_padded = GridwiseGemm::CalculateMPadded(karg.M);
// const index_t n_padded = GridwiseGemm::CalculateNPadded(karg.N);
// const index_t n_padded = GridwiseGemm::CalculateNPadded(karg.N);
index_t
gdx
,
gdy
,
gdz
;
index_t
gdx
,
gdy
,
gdz
;
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
karg
.
M
,
karg
.
N
,
karg
.
KBatch
);
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
karg
.
M
,
karg
.
N
,
karg
.
KBatch
);
const
auto
local_b2c_tile_map
=
Block2ETileMap
{
gdx
,
gdy
,
gdz
};
const
auto
local_b2c_tile_map
=
Block2ETileMap
{
gdx
,
gdy
,
gdz
};
const
index_t
grid_size_grp
=
local_b2c_tile_map
.
CalculateGridSize
(
karg
.
M
,
karg
.
N
);
const
index_t
grid_size_grp
=
local_b2c_tile_map
.
CalculateGridSize
(
karg
.
M
,
karg
.
N
);
...
@@ -344,14 +341,13 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
...
@@ -344,14 +341,13 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const
index_t
block_end
=
grid_size_
+
grid_size_grp
;
const
index_t
block_end
=
grid_size_
+
grid_size_grp
;
grid_size_
+=
grid_size_grp
;
grid_size_
+=
grid_size_grp
;
// auto grouped_block_2_ctile_map =
// auto grouped_block_2_ctile_map =
// GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
// GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
karg
.
KBatch
=
K_BATCH
;
karg
.
KBatch
=
K_BATCH
;
// gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
// gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
gemm_kernel_args_
[
i
].
block_start_
=
block_start
;
gemm_kernel_args_
[
i
].
block_start_
=
block_start
;
gemm_kernel_args_
[
i
].
block_end_
=
block_end
;
gemm_kernel_args_
[
i
].
block_end_
=
block_end
;
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment