Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
9f65d608
Commit
9f65d608
authored
Feb 15, 2025
by
Bartlomiej Kocot
Browse files
[CK TILE] Gemm pk_int4_t permute B
parent
0328b06e
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
205 additions
and
101 deletions
+205
-101
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+65
-0
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+40
-5
example/ck_tile/03_gemm/universal_gemm.cpp
example/ck_tile/03_gemm/universal_gemm.cpp
+37
-83
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+55
-12
include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp
include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp
+8
-1
No files found.
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
9f65d608
...
@@ -35,6 +35,71 @@
...
@@ -35,6 +35,71 @@
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif
#endif
struct
GemmBasicConfig
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Memory friendly for Interwave scheduler
static
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
static
constexpr
ck_tile
::
index_t
N_Tile
=
32
;
static
constexpr
ck_tile
::
index_t
K_Tile
=
64
;
static
constexpr
ck_tile
::
index_t
M_Warp
=
4
;
static
constexpr
ck_tile
::
index_t
N_Warp
=
1
;
static
constexpr
ck_tile
::
index_t
K_Warp
=
1
;
static
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
static
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
static
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
static
constexpr
bool
DoubleSmemBuffer
=
false
;
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
// Compute friendly for Intrawave scheduler
static
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
static
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
static
constexpr
ck_tile
::
index_t
K_Tile
=
64
;
static
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
static
constexpr
ck_tile
::
index_t
N_Warp
=
2
;
static
constexpr
ck_tile
::
index_t
K_Warp
=
1
;
static
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
static
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
static
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
16
;
static
constexpr
bool
DoubleSmemBuffer
=
false
;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
static
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
static
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
static
constexpr
ck_tile
::
index_t
K_Tile
=
32
;
static
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
static
constexpr
ck_tile
::
index_t
N_Warp
=
2
;
static
constexpr
ck_tile
::
index_t
K_Warp
=
1
;
static
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
static
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
static
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
16
;
static
constexpr
bool
DoubleSmemBuffer
=
true
;
#endif
static
constexpr
bool
kPadM
=
false
;
static
constexpr
bool
kPadN
=
false
;
static
constexpr
bool
kPadK
=
false
;
static
constexpr
bool
PermuteA
=
false
;
static
constexpr
bool
PermuteB
=
false
;
static
constexpr
bool
TransposeC
=
false
;
static
constexpr
int
kBlockPerCu
=
1
;
static
constexpr
ck_tile
::
index_t
TileParitionerGroupNum
=
8
;
static
constexpr
ck_tile
::
index_t
TileParitionerM01
=
4
;
};
template
<
typename
ADataType
,
typename
BDataType
=
ADataType
,
typename
CDataType
=
ADataType
>
template
<
typename
ADataType
,
typename
BDataType
=
ADataType
,
typename
CDataType
=
ADataType
>
struct
GemmBasicTypeConfig
;
struct
GemmBasicTypeConfig
;
...
...
example/ck_tile/03_gemm/run_gemm_example.inc
View file @
9f65d608
...
@@ -29,8 +29,32 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
...
@@ -29,8 +29,32 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
// Use higher threshold
// Use higher threshold
return
ck_tile
::
make_tuple
(
std
::
max
(
rtol
,
rtol_split_k
),
std
::
max
(
atol
,
atol_split_k
));
return
ck_tile
::
make_tuple
(
std
::
max
(
rtol
,
rtol_split_k
),
std
::
max
(
atol
,
atol_split_k
));
}
}
template
<
typename
Tensor
>
template
<
typename
Tensor
>
void
permute_tensor_b
(
Tensor
&
tensor
)
void
permute_tensor_b
(
Tensor
&
tensor
)
{
const
ck_tile
::
index_t
K
=
tensor
.
get_length
(
0
);
const
ck_tile
::
index_t
N
=
tensor
.
get_length
(
1
);
const
ck_tile
::
index_t
K1
=
GemmBasicConfig
::
K_Tile
;
const
ck_tile
::
index_t
K0
=
K
/
GemmBasicConfig
::
K_Tile
;
Tensor
tensor_copy
=
tensor
;
// int K0, N, K1
for
(
int
j
=
0
;
j
<
K0
;
j
++
)
{
for
(
int
i
=
0
;
i
<
N
;
i
++
)
{
for
(
int
jj
=
0
;
jj
<
K1
;
jj
++
)
{
tensor
(
j
*
N
*
K1
+
i
*
K1
+
jj
)
=
tensor_copy
(
i
*
K
+
(
j
*
K1
+
jj
));
}
}
}
}
template
<
typename
Tensor
>
void
permute_vectors_i4x4_b
(
Tensor
&
tensor
)
{
{
const
ck_tile
::
index_t
K
=
tensor
.
get_length
(
0
);
const
ck_tile
::
index_t
K
=
tensor
.
get_length
(
0
);
const
ck_tile
::
index_t
N
=
tensor
.
get_length
(
1
);
const
ck_tile
::
index_t
N
=
tensor
.
get_length
(
1
);
...
@@ -183,8 +207,8 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -183,8 +207,8 @@ int run_gemm_example_with_layouts(int argc,
if
(
init_method
==
0
)
if
(
init_method
==
0
)
{
{
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
1
.
f
,
1
.
f
}(
a_m_k
);
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
5
.
f
,
5
.
f
}(
a_m_k
);
ck_tile
::
FillUniformDistribution
<
BDataType
>
{
-
1
.
f
,
1
.
f
}(
b_k_n
);
ck_tile
::
FillUniformDistribution
<
BDataType
>
{
-
5
.
f
,
5
.
f
}(
b_k_n
);
}
}
else
if
(
init_method
==
1
)
else
if
(
init_method
==
1
)
{
{
...
@@ -206,18 +230,29 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -206,18 +230,29 @@ int run_gemm_example_with_layouts(int argc,
ck_tile
::
DeviceMem
b_k_n_dev_buf
(
b_k_n
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
b_k_n_dev_buf
(
b_k_n
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
c_m_n_dev_buf
(
c_m_n_dev_result
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
c_m_n_dev_buf
(
c_m_n_dev_result
.
get_element_space_size_in_bytes
());
a_m_k_dev_buf
.
ToDevice
(
a_m_k
.
data
()
);
static_assert
(
!
GemmBasicConfig
::
PermuteA
,
"Not implemented"
);
if
constexpr
(
std
::
is_same_v
<
BDataType
,
ck_tile
::
pk_int4_t
>
)
if
constexpr
(
std
::
is_same_v
<
BDataType
,
ck_tile
::
pk_int4_t
>
)
{
{
// Permute data for device implementation
// Permute
vector pk_i4x4
data for device implementation
ck_tile
::
HostTensor
<
BDataType
>
b_k_n_dev
=
b_k_n
;
ck_tile
::
HostTensor
<
BDataType
>
b_k_n_dev
=
b_k_n
;
permute_tensor_b
(
b_k_n_dev
);
if
constexpr
(
GemmBasicConfig
::
PermuteB
)
{
permute_tensor_b
(
b_k_n_dev
);
}
permute_vectors_i4x4_b
(
b_k_n_dev
);
b_k_n_dev_buf
.
ToDevice
(
b_k_n_dev
.
data
());
b_k_n_dev_buf
.
ToDevice
(
b_k_n_dev
.
data
());
}
}
else
else
{
{
if
constexpr
(
GemmBasicConfig
::
PermuteB
)
{
std
::
cout
<<
"Permute for this DataType is not implemented."
<<
std
::
endl
;
return
false
;
}
b_k_n_dev_buf
.
ToDevice
(
b_k_n
.
data
());
b_k_n_dev_buf
.
ToDevice
(
b_k_n
.
data
());
}
}
a_m_k_dev_buf
.
ToDevice
(
a_m_k
.
data
());
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
...
...
example/ck_tile/03_gemm/universal_gemm.cpp
View file @
9f65d608
...
@@ -21,90 +21,42 @@ template <typename ADataType,
...
@@ -21,90 +21,42 @@ template <typename ADataType,
typename
CLayout
>
typename
CLayout
>
float
gemm_calc
(
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
float
gemm_calc
(
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
using
GemmShape
=
ck_tile
::
TileGemmShape
<
// Memory friendly for Interwave scheduler
ck_tile
::
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
sequence
<
GemmBasicConfig
::
M_Tile
,
GemmBasicConfig
::
N_Tile
,
GemmBasicConfig
::
K_Tile
>
,
constexpr
ck_tile
::
index_t
N_Tile
=
32
;
ck_tile
::
constexpr
ck_tile
::
index_t
K_Tile
=
64
;
sequence
<
GemmBasicConfig
::
M_Warp
,
GemmBasicConfig
::
N_Warp
,
GemmBasicConfig
::
K_Warp
>
,
ck_tile
::
sequence
<
GemmBasicConfig
::
M_Warp_Tile
,
constexpr
ck_tile
::
index_t
M_Warp
=
4
;
GemmBasicConfig
::
N_Warp_Tile
,
constexpr
ck_tile
::
index_t
N_Warp
=
1
;
GemmBasicConfig
::
K_Warp_Tile
>
,
constexpr
ck_tile
::
index_t
K_Warp
=
1
;
GemmBasicConfig
::
PermuteA
,
GemmBasicConfig
::
PermuteB
>
;
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
using
TilePartitioner
=
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
ck_tile
::
GemmSpatiallyLocalTilePartitioner
<
GemmShape
,
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
GemmBasicConfig
::
TileParitionerGroupNum
,
GemmBasicConfig
::
TileParitionerM01
>
;
constexpr
bool
DoubleSmemBuffer
=
false
;
#endif
using
Traits
=
ck_tile
::
TileGemmTraits
<
GemmBasicConfig
::
kPadM
,
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
GemmBasicConfig
::
kPadN
,
// Compute friendly for Intrawave scheduler
GemmBasicConfig
::
kPadK
,
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
ALayout
,
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
BLayout
,
constexpr
ck_tile
::
index_t
K_Tile
=
64
;
CLayout
>
;
using
GemmUniversalTraits
=
ck_tile
::
TileGemmUniversalTraits
<
GemmBasicConfig
::
kPadM
,
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
GemmBasicConfig
::
kPadN
,
constexpr
ck_tile
::
index_t
N_Warp
=
2
;
GemmBasicConfig
::
kPadK
,
constexpr
ck_tile
::
index_t
K_Warp
=
1
;
GemmBasicConfig
::
DoubleSmemBuffer
,
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
16
;
constexpr
bool
DoubleSmemBuffer
=
false
;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
constexpr
ck_tile
::
index_t
K_Tile
=
32
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
constexpr
ck_tile
::
index_t
N_Warp
=
2
;
constexpr
ck_tile
::
index_t
K_Warp
=
1
;
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
16
;
constexpr
bool
DoubleSmemBuffer
=
true
;
#endif
constexpr
bool
kPadM
=
false
;
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadK
=
false
;
constexpr
bool
TransposeC
=
false
;
constexpr
int
kBlockPerCu
=
1
;
constexpr
ck_tile
::
index_t
TileParitionerGroupNum
=
8
;
constexpr
ck_tile
::
index_t
TileParitionerM01
=
4
;
// ===============================================
using
GemmShape
=
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmSpatiallyLocalTilePartitioner
<
GemmShape
,
TileParitionerGroupNum
,
TileParitionerM01
>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
GemmUniversalTraits
=
ck_tile
::
TileGemmUniversalTraits
<
kPadM
,
kPadN
,
kPadK
,
DoubleSmemBuffer
,
ALayout
,
ALayout
,
BLayout
,
BLayout
,
CLayout
,
CLayout
,
TransposeC
>
;
GemmBasicConfig
::
TransposeC
>
;
using
GemmPipelineProblem
=
using
GemmPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>
;
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>
;
using
BaseGemmPipeline
=
UNIVERSAL_GEMM_PIPELINE
<
GemmPipelineProblem
>
;
using
BaseGemmPipeline
=
UNIVERSAL_GEMM_PIPELINE
<
GemmPipelineProblem
>
;
const
ck_tile
::
index_t
k_grain
=
args
.
k_batch
*
K_Tile
;
const
ck_tile
::
index_t
k_grain
=
args
.
k_batch
*
GemmBasicConfig
::
K_Tile
;
const
ck_tile
::
index_t
K_split
=
(
args
.
K
+
k_grain
-
1
)
/
k_grain
*
K_Tile
;
const
ck_tile
::
index_t
K_split
=
(
args
.
K
+
k_grain
-
1
)
/
k_grain
*
GemmBasicConfig
::
K_Tile
;
const
ck_tile
::
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
K_split
);
const
ck_tile
::
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
K_split
);
const
bool
has_hot_loop
=
BaseGemmPipeline
::
BlockHasHotloop
(
num_loop
);
const
bool
has_hot_loop
=
BaseGemmPipeline
::
BlockHasHotloop
(
num_loop
);
const
ck_tile
::
TailNumber
tail_num
=
BaseGemmPipeline
::
GetBlockLoopTailNum
(
num_loop
);
const
ck_tile
::
TailNumber
tail_num
=
BaseGemmPipeline
::
GetBlockLoopTailNum
(
num_loop
);
...
@@ -133,11 +85,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -133,11 +85,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
GemmPipelineProblem
::
kBlockSize
,
GemmPipelineProblem
::
kBlockSize
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
,
TilePartitioner
::
NPerBlock
,
M_Warp
,
GemmBasicConfig
::
M_Warp
,
N_Warp
,
GemmBasicConfig
::
N_Warp
,
M_Warp_Tile
,
GemmBasicConfig
::
M_Warp_Tile
,
N_Warp_Tile
,
GemmBasicConfig
::
N_Warp_Tile
,
K_Warp_Tile
,
GemmBasicConfig
::
K_Warp_Tile
,
UniversalGemmProblem
::
TransposeC
>>
;
UniversalGemmProblem
::
TransposeC
>>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
...
@@ -158,8 +110,10 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -158,8 +110,10 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
<<
std
::
endl
;
<<
std
::
endl
;
}
}
ave_time
=
ck_tile
::
launch_kernel
(
ave_time
=
s
,
ck_tile
::
make_kernel
<
blocks
.
x
,
kBlockPerCu
>
(
Kernel
{},
grids
,
blocks
,
0
,
kargs
));
ck_tile
::
launch_kernel
(
s
,
ck_tile
::
make_kernel
<
blocks
.
x
,
GemmBasicConfig
::
kBlockPerCu
>
(
Kernel
{},
grids
,
blocks
,
0
,
kargs
));
return
ave_time
;
return
ave_time
;
};
};
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
9f65d608
...
@@ -279,6 +279,7 @@ struct GemmKernel
...
@@ -279,6 +279,7 @@ struct GemmKernel
const
GemmKernelArgs
&
kargs
,
const
GemmKernelArgs
&
kargs
,
const
SplitKBatchOffset
&
splitk_batch_offset
)
const
SplitKBatchOffset
&
splitk_batch_offset
)
{
{
static_assert
(
!
TilePartitioner
::
BlockGemmShape
::
PermuteA
,
"Not implemented!"
);
const
auto
&
a_tensor_view
=
[
&
]()
{
const
auto
&
a_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
...
@@ -303,21 +304,63 @@ struct GemmKernel
...
@@ -303,21 +304,63 @@ struct GemmKernel
const
auto
&
b_tensor_view
=
[
&
]()
{
const
auto
&
b_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
if
constexpr
(
TilePartitioner
::
BlockGemmShape
::
PermuteB
)
b_ptr
,
{
make_tuple
(
splitk_batch_offset
.
splitted_k
,
kargs
.
N
),
const
index_t
K1
=
TilePartitioner
::
BlockGemmShape
::
kK
;
make_tuple
(
kargs
.
stride_B
,
1
),
const
index_t
K0
=
number
<
GemmPipeline
::
GetVectorSizeB
()
>
{},
splitk_batch_offset
.
splitted_k
/
TilePartitioner
::
BlockGemmShape
::
kK
;
number
<
1
>
{});
const
auto
b_k0_n_k1_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
kargs
.
N
,
K1
),
make_tuple
(
kargs
.
N
*
K1
,
K1
,
I1
),
number
<
GemmPipeline
::
GetVectorSizeB
()
>
{},
number
<
1
>
{});
const
auto
b_n_k_desc
=
transform_tensor_descriptor
(
b_k0_n_k1_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
K1
)),
make_pass_through_transform
(
kargs
.
N
)),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
make_tensor_view
<
address_space_enum
::
global
>
(
b_ptr
,
b_n_k_desc
);
}
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_ptr
,
make_tuple
(
splitk_batch_offset
.
splitted_k
,
kargs
.
N
),
make_tuple
(
kargs
.
stride_B
,
1
),
number
<
GemmPipeline
::
GetVectorSizeB
()
>
{},
number
<
1
>
{});
}
}
}
else
else
{
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
if
constexpr
(
TilePartitioner
::
BlockGemmShape
::
PermuteB
)
b_ptr
,
{
make_tuple
(
kargs
.
N
,
splitk_batch_offset
.
splitted_k
),
const
index_t
K1
=
TilePartitioner
::
BlockGemmShape
::
kK
;
make_tuple
(
kargs
.
stride_B
,
1
),
const
index_t
K0
=
number
<
GemmPipeline
::
GetVectorSizeB
()
>
{},
splitk_batch_offset
.
splitted_k
/
TilePartitioner
::
BlockGemmShape
::
kK
;
number
<
1
>
{});
const
auto
b_k0_n_k1_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
kargs
.
N
,
K1
),
make_tuple
(
kargs
.
N
*
K1
,
K1
,
I1
),
number
<
GemmPipeline
::
GetVectorSizeB
()
>
{},
number
<
1
>
{});
const
auto
b_n_k_desc
=
transform_tensor_descriptor
(
b_k0_n_k1_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
K1
)),
make_pass_through_transform
(
kargs
.
N
)),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
return
make_tensor_view
<
address_space_enum
::
global
>
(
b_ptr
,
b_n_k_desc
);
}
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_ptr
,
make_tuple
(
kargs
.
N
,
splitk_batch_offset
.
splitted_k
),
make_tuple
(
kargs
.
stride_B
,
1
),
number
<
GemmPipeline
::
GetVectorSizeB
()
>
{},
number
<
1
>
{});
}
}
}
}();
}();
...
...
include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp
View file @
9f65d608
...
@@ -8,7 +8,11 @@
...
@@ -8,7 +8,11 @@
namespace
ck_tile
{
namespace
ck_tile
{
template
<
typename
BlockTile_
,
typename
BlockWarps_
,
typename
WarpTile_
>
template
<
typename
BlockTile_
,
typename
BlockWarps_
,
typename
WarpTile_
,
bool
PermuteA_
=
false
,
bool
PermuteB_
=
false
>
struct
TileGemmShape
struct
TileGemmShape
{
{
using
BlockTile
=
remove_cvref_t
<
BlockTile_
>
;
using
BlockTile
=
remove_cvref_t
<
BlockTile_
>
;
...
@@ -21,6 +25,9 @@ struct TileGemmShape
...
@@ -21,6 +25,9 @@ struct TileGemmShape
static
constexpr
index_t
kN
=
BlockTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kN
=
BlockTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kK
=
BlockTile
::
at
(
number
<
2
>
{});
static
constexpr
index_t
kK
=
BlockTile
::
at
(
number
<
2
>
{});
static
constexpr
bool
PermuteA
=
PermuteA_
;
static
constexpr
bool
PermuteB
=
PermuteB_
;
CK_TILE_HOST
static
std
::
string
GetName
()
CK_TILE_HOST
static
std
::
string
GetName
()
{
{
// clang-format off
// clang-format off
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment