Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
928b6d1a
Commit
928b6d1a
authored
Dec 04, 2024
by
coderfeli
Browse files
split smem to 2array, but still same
parent
c275904b
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
22 additions
and
26 deletions
+22
-26
include/ck_tile/ops/epilogue/cshuffle_epilogue_v2.hpp
include/ck_tile/ops/epilogue/cshuffle_epilogue_v2.hpp
+3
-3
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+3
-2
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+14
-17
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+2
-4
No files found.
include/ck_tile/ops/epilogue/cshuffle_epilogue_v2.hpp
View file @
928b6d1a
...
@@ -41,7 +41,7 @@ CK_TILE_HOST_DEVICE static constexpr auto MakeOLdsBlockDescriptor()
...
@@ -41,7 +41,7 @@ CK_TILE_HOST_DEVICE static constexpr auto MakeOLdsBlockDescriptor()
}
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
65536
;
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
32768
;
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeODramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeODramTileDistribution
()
{
{
...
@@ -87,7 +87,7 @@ struct CShuffleEpilogueV2
...
@@ -87,7 +87,7 @@ struct CShuffleEpilogueV2
// static constexpr bool kMPerBlock = 64;
// static constexpr bool kMPerBlock = 64;
static
constexpr
index_t
kNPerBlock
=
Problem
::
kNPerBlock
;
static
constexpr
index_t
kNPerBlock
=
Problem
::
kNPerBlock
;
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
65536
;}
//
kMPerBlock * kNPerBlock * sizeof(ODataType); }
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
kMPerBlock
*
kNPerBlock
*
sizeof
(
ODataType
);
}
// TODO: this function assumes the store-out vector size is the same as the OAccTile last-dimension size
// TODO: this function assumes the store-out vector size is the same as the OAccTile last-dimension size
// How do we fix this?
// How do we fix this?
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
928b6d1a
...
@@ -161,13 +161,14 @@ struct GemmKernel
...
@@ -161,13 +161,14 @@ struct GemmKernel
{
i_n
,
0
});
{
i_n
,
0
});
// allocate LDS
// allocate LDS
__shared__
char
smem_ptr
[
GetSmemSize
()];
__shared__
char
smem_ptr_0
[
GetSmemSize
()];
__shared__
char
smem_ptr_1
[
GetSmemSize
()];
const
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
kargs
.
K
);
const
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
kargs
.
K
);
// Run GEMM cooperatively by the whole workgroup.
// Run GEMM cooperatively by the whole workgroup.
auto
c_block_tile
=
auto
c_block_tile
=
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr
);
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr
_0
,
smem_ptr_1
);
CDataType
*
c_start
=
static_cast
<
CDataType
*>
(
kargs
.
c_ptr
);
CDataType
*
c_start
=
static_cast
<
CDataType
*>
(
kargs
.
c_ptr
);
auto
c_tensor_view
=
[
&
]()
{
auto
c_tensor_view
=
[
&
]()
{
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
928b6d1a
...
@@ -165,7 +165,8 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -165,7 +165,8 @@ struct GemmPipelineAGmemBGmemCRegV1
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
const
BElementFunction
&
b_element_func
,
const
BElementFunction
&
b_element_func
,
index_t
num_loop
,
index_t
num_loop
,
void
*
p_smem
)
const
void
*
__restrict__
p_smem_0
,
void
*
__restrict__
p_smem_1
)
const
{
{
static_assert
(
static_assert
(
std
::
is_same_v
<
ADataType
,
remove_cvref_t
<
typename
ADramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
ADataType
,
remove_cvref_t
<
typename
ADramBlockWindowTmp
::
DataType
>>
&&
...
@@ -209,26 +210,20 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -209,26 +210,20 @@ struct GemmPipelineAGmemBGmemCRegV1
constexpr
auto
b_lds_block_desc
=
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>();
constexpr
auto
b_lds_block_desc
=
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>();
constexpr
index_t
a_lds_block_space_size_aligned
=
constexpr
index_t
a_lds_block_space_size_aligned
=
integer_least_multiple
(
sizeof
(
ADataType
)
*
a_lds_block_desc
.
get_element_space_size
(),
16
);
integer_least_multiple
(
sizeof
(
ADataType
)
*
a_lds_block_desc
.
get_element_space_size
(),
16
);
constexpr
index_t
b_lds_block_space_size_aligned
=
integer_least_multiple
(
sizeof
(
BDataType
)
*
b_lds_block_desc
.
get_element_space_size
(),
16
);
// A tile in LDS view
// A tile in LDS view
const
ADataType
*
__restrict__
p_a_lds0
=
reinterpret_cast
<
ADataType
*>
(
p_smem
);
const
ADataType
*
__restrict__
p_a_lds0
=
reinterpret_cast
<
ADataType
*>
(
p_smem_0
);
const
ADataType
*
__restrict__
p_a_lds1
=
reinterpret_cast
<
ADataType
*>
(
reinterpret_cast
<
char
*>
(
p_smem
)
+
a_lds_block_space_size_aligned
);
const
ADataType
*
__restrict__
p_a_lds1
=
reinterpret_cast
<
ADataType
*>
(
p_smem_1
);
const
ADataType
*
__restrict__
p_a_lds2
=
reinterpret_cast
<
ADataType
*>
(
p_smem
);
const
ADataType
*
__restrict__
p_a_lds3
=
reinterpret_cast
<
ADataType
*>
(
reinterpret_cast
<
char
*>
(
p_smem
)
+
a_lds_block_space_size_aligned
);
auto
a_lds_block0
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds0
,
a_lds_block_desc
);
auto
a_lds_block0
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds0
,
a_lds_block_desc
);
auto
a_lds_block1
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds1
,
a_lds_block_desc
);
auto
a_lds_block1
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds1
,
a_lds_block_desc
);
auto
a_lds_ld_block0
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds
2
,
a_lds_block_desc
);
auto
a_lds_ld_block0
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds
0
,
a_lds_block_desc
);
auto
a_lds_ld_block1
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds
3
,
a_lds_block_desc
);
auto
a_lds_ld_block1
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds
1
,
a_lds_block_desc
);
// B tile in LDS view
// B tile in LDS view
const
BDataType
*
__restrict__
p_b_lds0
=
reinterpret_cast
<
BDataType
*>
(
reinterpret_cast
<
char
*>
(
p_smem
)
+
a_lds_block_space_size_aligned
*
2
);
const
BDataType
*
__restrict__
p_b_lds0
=
reinterpret_cast
<
BDataType
*>
(
reinterpret_cast
<
char
*>
(
p_smem_0
)
+
a_lds_block_space_size_aligned
);
const
BDataType
*
__restrict__
p_b_lds1
=
reinterpret_cast
<
BDataType
*>
(
reinterpret_cast
<
char
*>
(
p_smem
)
+
a_lds_block_space_size_aligned
*
2
+
b_lds_block_space_size_aligned
);
const
BDataType
*
__restrict__
p_b_lds1
=
reinterpret_cast
<
BDataType
*>
(
reinterpret_cast
<
char
*>
(
p_smem_1
)
+
a_lds_block_space_size_aligned
);
const
BDataType
*
__restrict__
p_b_lds2
=
reinterpret_cast
<
BDataType
*>
(
reinterpret_cast
<
char
*>
(
p_smem
)
+
a_lds_block_space_size_aligned
*
2
);
const
BDataType
*
__restrict__
p_b_lds3
=
reinterpret_cast
<
BDataType
*>
(
reinterpret_cast
<
char
*>
(
p_smem
)
+
a_lds_block_space_size_aligned
*
2
+
b_lds_block_space_size_aligned
);
auto
b_lds_block0
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds0
,
b_lds_block_desc
);
auto
b_lds_block0
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds0
,
b_lds_block_desc
);
auto
b_lds_block1
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds1
,
b_lds_block_desc
);
auto
b_lds_block1
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds1
,
b_lds_block_desc
);
auto
b_lds_ld_block0
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds
2
,
b_lds_block_desc
);
auto
b_lds_ld_block0
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds
0
,
b_lds_block_desc
);
auto
b_lds_ld_block1
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds
3
,
b_lds_block_desc
);
auto
b_lds_ld_block1
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds
1
,
b_lds_block_desc
);
// A LDS tile window for store
// A LDS tile window for store
auto
a_lds_window0
=
make_tile_window_linear
(
auto
a_lds_window0
=
make_tile_window_linear
(
...
@@ -392,7 +387,8 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -392,7 +387,8 @@ struct GemmPipelineAGmemBGmemCRegV1
CK_TILE_DEVICE
auto
operator
()(
const
ADramBlockWindowTmp
&
a_dram_block_window_tmp
,
CK_TILE_DEVICE
auto
operator
()(
const
ADramBlockWindowTmp
&
a_dram_block_window_tmp
,
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
index_t
num_loop
,
index_t
num_loop
,
void
*
p_smem
)
const
void
*
__restrict__
p_smem_0
,
void
*
__restrict__
p_smem_1
)
const
{
{
return
operator
()(
return
operator
()(
a_dram_block_window_tmp
,
a_dram_block_window_tmp
,
...
@@ -400,7 +396,8 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -400,7 +396,8 @@ struct GemmPipelineAGmemBGmemCRegV1
b_dram_block_window_tmp
,
b_dram_block_window_tmp
,
[](
const
BDataType
&
b
)
{
return
b
;
},
[](
const
BDataType
&
b
)
{
return
b
;
},
num_loop
,
num_loop
,
p_smem
);
p_smem_0
,
p_smem_1
);
}
}
};
};
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
928b6d1a
...
@@ -102,8 +102,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
...
@@ -102,8 +102,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeA
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeA
()
{
{
constexpr
index_t
smem_size_a
=
integer_least_multiple
(
sizeof
(
typename
Problem
::
ADataType
)
*
constexpr
index_t
smem_size_a
=
integer_least_multiple
(
sizeof
(
typename
Problem
::
ADataType
)
*
MakeALdsBlockDescriptor
<
Problem
>
().
get_element_space_size
(),
16
)
MakeALdsBlockDescriptor
<
Problem
>
().
get_element_space_size
(),
16
);
*
2
;
return
smem_size_a
;
return
smem_size_a
;
}
}
...
@@ -111,8 +110,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
...
@@ -111,8 +110,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeB
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeB
()
{
{
constexpr
index_t
smem_size_b
=
integer_least_multiple
(
sizeof
(
typename
Problem
::
BDataType
)
*
constexpr
index_t
smem_size_b
=
integer_least_multiple
(
sizeof
(
typename
Problem
::
BDataType
)
*
MakeBLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
(),
16
)
MakeBLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
(),
16
);
*
2
;
return
smem_size_b
;
return
smem_size_b
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment