Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
220f40c9
Commit
220f40c9
authored
Jan 20, 2025
by
rtmadduri
Browse files
fix gridsize calculations
parent
1c1da090
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
18 deletions
+21
-18
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
...u/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
+21
-18
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
View file @
220f40c9
...
...
@@ -200,7 +200,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
ComputeTypeA
,
ComputeTypeB
>
;
using
Block2ETileMap
=
typename
GridwiseGemm
::
Block2CTileMap
;
using
Block2ETileMap
=
BlockToCTileMap_Grouped_M00_N0_M01Adapt
<
8
,
MPerBlock
,
NPerBlock
>
;
using
GroupedGemmBlock2ETileMap
=
OffsettedBlockToCTileMap
<
Block2ETileMap
>
;
using
KernelArgument
=
typename
GridwiseGemm
::
Argument
;
...
...
@@ -209,16 +210,16 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
struct
GemmTransKernelArg
{
KernelArgument
karg_
;
GroupedGemmBlock2ETileMap
block_2_ctile_map_
;
//
GroupedGemmBlock2ETileMap block_2_ctile_map_;
index_t
block_start_
,
block_end_
;
GemmTransKernelArg
()
=
default
;
GemmTransKernelArg
(
KernelArgument
&&
karg
,
GroupedGemmBlock2ETileMap
&&
b2c_map
,
//
GroupedGemmBlock2ETileMap&& b2c_map,
index_t
block_start
,
index_t
block_end
)
:
karg_
{
karg
},
block_2_ctile_map_
{
b2c_map
},
//
block_2_ctile_map_{b2c_map},
block_start_
{
block_start
},
block_end_
{
block_end
}
{
...
...
@@ -277,15 +278,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const
index_t
stride_b
=
gemm_descs
[
i
].
stride_B_
;
const
index_t
stride_c
=
gemm_descs
[
i
].
stride_C_
;
// const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
// const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
index_t
gdx
,
gdy
,
gdz
;
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
M
,
N
,
K_BATCH
);
const
auto
local_b2c_tile_map
=
Block2ETileMap
{
gdx
,
gdy
,
gdz
};
//
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(
c_grid_desc_m_n
);
const
index_t
grid_size_grp
=
gdx
*
gdy
*
gdz
;
const
index_t
grid_size_grp
=
local_b2c_tile_map
.
CalculateGridSize
(
M
,
N
);
//
const index_t grid_size_grp = gdx * gdy * gdz;
const
index_t
block_start
=
grid_size_
;
const
index_t
block_end
=
grid_size_
+
grid_size_grp
;
...
...
@@ -293,8 +291,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
grid_size_
+=
grid_size_grp
;
// block-to-e-tile map
auto
grouped_block_2_ctile_map
=
GroupedGemmBlock2ETileMap
(
local_b2c_tile_map
,
block_start
);
//
auto grouped_block_2_ctile_map =
//
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
KernelArgument
karg
{
type_convert
<
const
ADataType
*>
(
p_a_grid
[
i
]),
type_convert
<
const
BDataType
*>
(
p_b_grid
[
i
]),
...
...
@@ -307,8 +305,13 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
stride_c
,
K_BATCH
};
// gemm_kernel_args_.emplace_back(
// std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
gemm_kernel_args_
.
emplace_back
(
std
::
move
(
karg
),
std
::
move
(
grouped_block_2_ctile_map
),
block_start
,
block_end
);
std
::
move
(
karg
),
block_start
,
block_end
);
}
}
...
...
@@ -334,19 +337,19 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
karg
.
M
,
karg
.
N
,
karg
.
KBatch
);
const
auto
local_b2c_tile_map
=
Block2ETileMap
{
gdx
,
gdy
,
gdz
};
//
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(
c_grid_desc_m_n
);
const
index_t
grid_size_grp
=
gdx
*
gdy
*
gdz
;
const
index_t
grid_size_grp
=
local_b2c_tile_map
.
CalculateGridSize
(
karg
.
M
,
karg
.
N
);
//
const index_t grid_size_grp = gdx * gdy * gdz;
const
index_t
block_start
=
grid_size_
;
const
index_t
block_end
=
grid_size_
+
grid_size_grp
;
grid_size_
+=
grid_size_grp
;
auto
grouped_block_2_ctile_map
=
GroupedGemmBlock2ETileMap
(
local_b2c_tile_map
,
block_start
);
//
auto grouped_block_2_ctile_map =
//
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
karg
.
KBatch
=
K_BATCH
;
gemm_kernel_args_
[
i
].
block_2_ctile_map_
=
grouped_block_2_ctile_map
;
//
gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
gemm_kernel_args_
[
i
].
block_start_
=
block_start
;
gemm_kernel_args_
[
i
].
block_end_
=
block_end
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment