gaoqiong / composable_kernel_ROCM / Commits

Commit 9da21f99, authored Feb 06, 2025 by Andriy Roshchenko

    Merging origin/lwpck-2663 into andriy/lwpck-2829

Parents: 17617a15, 76b33666

Showing 5 changed files with 451 additions and 106 deletions (+451, -106):
    example/67_gemm_microscaling/gemm_mx_common.hpp                                  +22  -34
    include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp                 +236 -55
    include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp    +1   -1
    library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp  +178 -0
    test/ck_tile/gemm/test_gemm_pipeline.cpp                                         +14  -16
example/67_gemm_microscaling/gemm_mx_common.hpp

@@ -13,7 +13,7 @@
 #include "ck/utility/blkgemmpipe_scheduler.hpp"
 #include "ck/utility/data_type.hpp"
 #include "ck/utility/sequence.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/fill.hpp"

@@ -313,40 +313,27 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
         std::cout << "Computing GEMM on host..." << std::endl;
     }

-    Tensor<CDataType> c({M, N});
-    Tensor<float> a({M, K});
-    Tensor<float> b({K, N});
-
-    for(int m = 0; m < M; m++)
-    {
-        for(int k = 0; k < K; k++)
-        {
-            a(m, k) = ck::type_convert<float>(a_m_k(m, k)) *
-                      ck::type_convert<float>(a_m_k_scale(m, k / Scale_Block_K));
-        }
-    }
-
-    for(int n = 0; n < N; n++)
-    {
-        for(int k = 0; k < K; k++)
-        {
-            b(k, n) = ck::type_convert<float>(b_k_n(k, n)) *
-                      ck::type_convert<float>(b_k_n_scale(k / Scale_Block_K, n));
-        }
-    }
-
-    using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<float,
-                                                                            float,
-                                                                            CShuffleDataType,
-                                                                            CDataType,
-                                                                            PassThrough,
-                                                                            PassThrough,
-                                                                            PassThrough>;
+    using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMXGemm<ADataType,
+                                                                              BDataType,
+                                                                              CDataType,
+                                                                              AccDataType,
+                                                                              float,
+                                                                              PassThrough,
+                                                                              PassThrough,
+                                                                              PassThrough,
+                                                                              float,
+                                                                              float>;

     auto ref_gemm    = ReferenceGemmInstance{};
     auto ref_invoker = ref_gemm.MakeInvoker();

-    auto ref_argument =
-        ref_gemm.MakeArgument(a, b, c, PassThrough{}, PassThrough{}, PassThrough{});
+    auto ref_argument = ref_gemm.MakeArgument(a_m_k,
+                                              a_m_k_scale,
+                                              b_k_n,
+                                              b_k_n_scale,
+                                              c_m_n_host_result,
+                                              PassThrough{},
+                                              PassThrough{},
+                                              PassThrough{});

     ref_invoker.Run(ref_argument);

@@ -364,8 +351,9 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
                   << ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl;
     }

-    res_verified = res_verified &&
-                   ck::utils::check_err(c_m_n_device_result, c, "Error: Incorrect results!");
+    res_verified =
+        res_verified &&
+        ck::utils::check_err(c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!");

     if(config.verbosity > 0 && res_verified)
         std::cout << "Done." << std::endl;
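Note on this file: in MX (microscaling) formats, every Scale_Block_K consecutive values along K share one scale, so the effective operand is value * block_scale. The deleted host loops above spelled that dequantization out inline; it now lives behind ReferenceMXGemm. Below is a minimal standalone sketch of the same indexing, using plain C++ only; dequantize_a and scale_block_k are illustrative names, not CK API.

#include <cstddef>
#include <cstdio>
#include <vector>

// Dequantize a row-major M x K matrix whose K dimension is split into blocks
// of scale_block_k elements; each block of a row shares one scale.
std::vector<float> dequantize_a(const std::vector<float>& a,       // M*K quantized values
                                const std::vector<float>& a_scale, // M*(K/scale_block_k) scales
                                std::size_t M,
                                std::size_t K,
                                std::size_t scale_block_k)
{
    const std::size_t k_blocks = K / scale_block_k;
    std::vector<float> out(M * K);
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t k = 0; k < K; ++k)
            out[m * K + k] = a[m * K + k] * a_scale[m * k_blocks + k / scale_block_k];
    return out;
}

int main()
{
    // One row, K = 4, blocks of 2: scale 2.0 applies to k 0-1, scale 0.5 to k 2-3.
    const auto out = dequantize_a({1, 1, 1, 1}, {2.0f, 0.5f}, 1, 4, 2);
    std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]); // prints: 2 2 0.5 0.5
}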
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp

@@ -90,7 +90,7 @@ struct BaseGemmPipelineAgBgCrMem
 // LocalPreFillStages: 1
 // LocalPreFetchStages: 0
 // LocalSharedMemoryBuffer: 1
-template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
+template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
 struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
 {
     using Base = BaseGemmPipelineAgBgCrMem<Problem>;
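The only change in this hunk is the default Policy argument: tile-distribution and similar choices now come from UniversalGemmPipelineAgBgCrPolicy unless a caller overrides them. A toy model of that policy-template pattern, with made-up policy types (LegacyPolicy, UniversalPolicy, Pipeline are illustrative, not CK code):

#include <cstdio>

struct LegacyPolicy    { static constexpr int vector_size = 4; };
struct UniversalPolicy { static constexpr int vector_size = 8; };

template <typename Problem, typename Policy = UniversalPolicy> // swapped default
struct Pipeline
{
    static void run() { std::printf("vector size = %d\n", Policy::vector_size); }
};

struct MyProblem {};

int main()
{
    Pipeline<MyProblem>::run();               // picks up the new default policy
    Pipeline<MyProblem, LegacyPolicy>::run(); // old behavior still reachable explicitly
}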
@@ -165,11 +165,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
                       "A/B Dram block window should have the same data type as appropriate "
                       "([A|B]DataType) defined in Problem definition!");

-        static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
-                          NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
-                          KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
-                      "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
-                      " or KPerBlock!");
+        constexpr bool is_a_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
+        constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
+
+        static_assert(is_a_col_major
+                          ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
+                             MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
+                          : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
+                             KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
+                      "A block window has incorrect lengths for defined ALayout!");
+
+        static_assert(is_b_row_major
+                          ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
+                             NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
+                          : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
+                             KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
+                      "B block window has incorrect lengths for defined BLayout!");

         // ------------------------------------------------------------------------------------
         // Definitions of all needed tiles
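The replaced assert hard-coded one window orientation; the new pair derives the expected window lengths from the layout, so a column-major A window is checked as (KPerBlock, MPerBlock) instead of (MPerBlock, KPerBlock). A standalone sketch of that expectation, assuming made-up tile sizes (expected_a_window_lengths is illustrative, not CK code):

#include <array>
#include <cstdio>

template <bool IsAColMajor, int MPerBlock, int KPerBlock>
constexpr std::array<int, 2> expected_a_window_lengths()
{
    // Column-major A puts K first in the DRAM block window.
    return IsAColMajor ? std::array<int, 2>{KPerBlock, MPerBlock}
                       : std::array<int, 2>{MPerBlock, KPerBlock};
}

int main()
{
    constexpr auto row_major = expected_a_window_lengths<false, 128, 32>();
    constexpr auto col_major = expected_a_window_lengths<true, 128, 32>();
    static_assert(row_major[0] == 128 && row_major[1] == 32);
    static_assert(col_major[0] == 32 && col_major[1] == 128);
    std::printf("row-major A window: (%d, %d)\n", row_major[0], row_major[1]);
    std::printf("col-major A window: (%d, %d)\n", col_major[0], col_major[1]);
}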
@@ -213,25 +224,59 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
         tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
         tuple_array<BBlockTile, PrefetchStages> b_block_tiles;

+        using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
+        using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
+
+        constexpr ADramTileWindowStep a_dram_tile_window_step =
+            is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
+        constexpr BDramTileWindowStep b_dram_tile_window_step =
+            is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
+
         // -----------------------------------------------------------------------------------------
         // Gemm pipeline start

         // prefetch
         // global read 0
-        Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
-        Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);
+        Base::GlobalPrefetch(
+            a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step);
+        Base::GlobalPrefetch(
+            b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step);

         // initialize C
         tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

         // LDS write 0
-        Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
-        Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
+        if constexpr(is_a_col_major)
+        {
+            auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
+                Policy::template MakeShuffledARegTileDistribution<Problem>());
+            transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
+            Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
+        }
+        else
+        {
+            Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
+        }
+        if constexpr(is_b_row_major)
+        {
+            auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
+                Policy::template MakeShuffledBRegTileDistribution<Problem>());
+            transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{}));
+            Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
+        }
+        else
+        {
+            Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
+        }

         // Global prefetch [1, PrefetchStages]
         static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
-            Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
-            Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
+            Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
+                                 a_copy_dram_window,
+                                 a_dram_tile_window_step);
+            Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
+                                 b_copy_dram_window,
+                                 b_dram_tile_window_step);
         });

         // main body
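The prefetch windows now advance by an explicit per-layout step: when K is the leading window dimension (column-major A, row-major B) the window slides along dimension 0, otherwise along dimension 1. A toy model of that stepping with plain arrays; the names and sizes are illustrative only:

#include <array>
#include <cstdio>

int main()
{
    constexpr int  KPerBlock      = 32;
    constexpr bool is_a_col_major = true; // assumption for the demo

    // Column-major A stores K as the leading window dimension, so the window
    // slides along dimension 0; row-major A slides along dimension 1.
    constexpr std::array<int, 2> step =
        is_a_col_major ? std::array<int, 2>{KPerBlock, 0} : std::array<int, 2>{0, KPerBlock};

    std::array<int, 2> origin{0, 0};
    for(int iter = 0; iter < 3; ++iter)
    {
        std::printf("k-iter %d: window origin = (%d, %d)\n", iter, origin[0], origin[1]);
        origin[0] += step[0]; // move the window to the next K tile
        origin[1] += step[1];
    }
}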
@@ -247,19 +292,45 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
                 block_sync_lds();

-                Base::LocalPrefill(
-                    a_copy_lds_window,
-                    a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
-                    a_element_func);
-                Base::LocalPrefill(
-                    b_copy_lds_window,
-                    b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
-                    b_element_func);
+                if constexpr(is_a_col_major)
+                {
+                    auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
+                        Policy::template MakeShuffledARegTileDistribution<Problem>());
+                    transpose_tile2d(
+                        a_shuffle_tmp,
+                        a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
+                    Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
+                }
+                else
+                {
+                    Base::LocalPrefill(
+                        a_copy_lds_window,
+                        a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
+                        a_element_func);
+                }
+                if constexpr(is_b_row_major)
+                {
+                    auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
+                        Policy::template MakeShuffledBRegTileDistribution<Problem>());
+                    transpose_tile2d(
+                        b_shuffle_tmp,
+                        b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
+                    Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
+                }
+                else
+                {
+                    Base::LocalPrefill(
+                        b_copy_lds_window,
+                        b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
+                        b_element_func);
+                }

-                Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
-                                     a_copy_dram_window);
-                Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
-                                     b_copy_dram_window);
+                Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
+                                     a_copy_dram_window,
+                                     a_dram_tile_window_step);
+                Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
+                                     b_copy_dram_window,
+                                     b_dram_tile_window_step);
             });

             i += PrefetchStages;
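The main body keeps a ring of PrefetchStages register tiles: slot (prefetch_idx + 1) % PrefetchStages is drained into LDS while slot prefetch_idx is refilled from global memory. A toy model of that modular indexing, assuming PrefetchStages = 2 and integer "tiles"; purely illustrative, no CK types:

#include <cstdio>

int main()
{
    constexpr int PrefetchStages = 2;
    int ring[PrefetchStages] = {0, 1}; // pretend tiles 0 and 1 were prefetched;
                                       // tile 0 was already consumed by "LDS write 0"

    int next_tile = PrefetchStages;
    for(int i = 0; i < 6; ++i)
    {
        const int prefetch_idx = i % PrefetchStages;
        const int consume_idx  = (prefetch_idx + 1) % PrefetchStages;
        std::printf("iter %d: consume tile %d from slot %d, refill slot %d with tile %d\n",
                    i, ring[consume_idx], consume_idx, prefetch_idx, next_tile);
        ring[prefetch_idx] = next_tile++; // overwrite the slot whose tile was already consumed
    }
}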
@@ -275,12 +346,32 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
                 block_sync_lds();

-                Base::LocalPrefill(a_copy_lds_window,
-                                   a_block_tiles.get(number<prefetch_idx>{}),
-                                   a_element_func);
-                Base::LocalPrefill(b_copy_lds_window,
-                                   b_block_tiles.get(number<prefetch_idx>{}),
-                                   b_element_func);
+                if constexpr(is_a_col_major)
+                {
+                    auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
+                        Policy::template MakeShuffledARegTileDistribution<Problem>());
+                    transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number<prefetch_idx>{}));
+                    Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
+                }
+                else
+                {
+                    Base::LocalPrefill(a_copy_lds_window,
+                                       a_block_tiles.get(number<prefetch_idx>{}),
+                                       a_element_func);
+                }
+                if constexpr(is_b_row_major)
+                {
+                    auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
+                        Policy::template MakeShuffledBRegTileDistribution<Problem>());
+                    transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number<prefetch_idx>{}));
+                    Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
+                }
+                else
+                {
+                    Base::LocalPrefill(b_copy_lds_window,
+                                       b_block_tiles.get(number<prefetch_idx>{}),
+                                       b_element_func);
+                }
             });

             block_sync_lds();
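In the shuffled paths above, a tile loaded K-fastest is transposed in registers (transpose_tile2d) before the LDS prefill, so shared memory always holds the tile in the canonical orientation. A plain-array sketch of that 2D transpose; arrays stand in for CK's distributed tensors:

#include <array>
#include <cstdio>

template <int Rows, int Cols>
std::array<std::array<int, Rows>, Cols> transpose(const std::array<std::array<int, Cols>, Rows>& in)
{
    std::array<std::array<int, Rows>, Cols> out{};
    for(int r = 0; r < Rows; ++r)
        for(int c = 0; c < Cols; ++c)
            out[c][r] = in[r][c]; // element (r, c) lands at (c, r)
    return out;
}

int main()
{
    // A 2x3 tile loaded "K-major" ...
    std::array<std::array<int, 3>, 2> kxm{{{1, 2, 3}, {4, 5, 6}}};
    // ... becomes a 3x2 tile in the canonical orientation before the LDS write.
    const auto mxk = transpose(kxm);
    for(const auto& row : mxk)
        std::printf("%d %d\n", row[0], row[1]); // prints: 1 4 / 2 5 / 3 6
}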
@@ -352,11 +443,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
                       "A/B Dram block window should have the same data type as appropriate "
                       "([A|B]DataType) defined in Problem definition!");

-        static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
-                          NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
-                          KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
-                      "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
-                      " or KPerBlock!");
+        constexpr bool is_a_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
+        constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
+
+        static_assert(is_a_col_major
+                          ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
+                             MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
+                          : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
+                             KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
+                      "A block window has incorrect lengths for defined ALayout!");
+
+        static_assert(is_b_row_major
+                          ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
+                             NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
+                          : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
+                             KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
+                      "B block window has incorrect lengths for defined BLayout!");

         // ------------------------------------------------------------------------------------
         // Definitions of all needed tiles
@@ -400,25 +502,58 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
         tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
         tuple_array<BBlockTile, PrefetchStages> b_block_tiles;

+        using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
+        using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
+
+        constexpr ADramTileWindowStep a_dram_tile_window_step =
+            is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
+        constexpr BDramTileWindowStep b_dram_tile_window_step =
+            is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
+
         // -----------------------------------------------------------------------------------------
         // Gemm pipeline start

         // prefetch
         // global read 0
-        Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
-        Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);
+        Base::GlobalPrefetch(
+            a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step);
+        Base::GlobalPrefetch(
+            b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step);

         // initialize C
         tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

         // LDS write 0
-        Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
-        Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
+        if constexpr(is_a_col_major)
+        {
+            auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
+                Policy::template MakeShuffledARegTileDistribution<Problem>());
+            transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
+            Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
+        }
+        else
+        {
+            Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
+        }
+        if constexpr(is_b_row_major)
+        {
+            auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
+                Policy::template MakeShuffledBRegTileDistribution<Problem>());
+            transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{}));
+            Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
+        }
+        else
+        {
+            Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
+        }

         // Global prefetch [1, PrefetchStages]
         static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
-            Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
-            Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
+            Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
+                                 a_copy_dram_window,
+                                 a_dram_tile_window_step);
+            Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
+                                 b_copy_dram_window,
+                                 b_dram_tile_window_step);
         });

         // main body
@@ -432,19 +567,45 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
                 block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
                 // no second block_sync_lds because it's interwave

-                Base::LocalPrefill(
-                    a_copy_lds_window,
-                    a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
-                    a_element_func);
-                Base::LocalPrefill(
-                    b_copy_lds_window,
-                    b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
-                    b_element_func);
+                if constexpr(is_a_col_major)
+                {
+                    auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
+                        Policy::template MakeShuffledARegTileDistribution<Problem>());
+                    transpose_tile2d(
+                        a_shuffle_tmp,
+                        a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
+                    Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
+                }
+                else
+                {
+                    Base::LocalPrefill(
+                        a_copy_lds_window,
+                        a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
+                        a_element_func);
+                }
+                if constexpr(is_b_row_major)
+                {
+                    auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
+                        Policy::template MakeShuffledBRegTileDistribution<Problem>());
+                    transpose_tile2d(
+                        b_shuffle_tmp,
+                        b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
+                    Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
+                }
+                else
+                {
+                    Base::LocalPrefill(
+                        b_copy_lds_window,
+                        b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
+                        b_element_func);
+                }

-                Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
-                                     a_copy_dram_window);
-                Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
-                                     b_copy_dram_window);
+                Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
+                                     a_copy_dram_window,
+                                     a_dram_tile_window_step);
+                Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
+                                     b_copy_dram_window,
+                                     b_dram_tile_window_step);
             });

             i += PrefetchStages;
@@ -457,12 +618,32 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
                 block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
                 // no second block_sync_lds because it's interwave

-                Base::LocalPrefill(a_copy_lds_window,
-                                   a_block_tiles.get(number<prefetch_idx>{}),
-                                   a_element_func);
-                Base::LocalPrefill(b_copy_lds_window,
-                                   b_block_tiles.get(number<prefetch_idx>{}),
-                                   b_element_func);
+                if constexpr(is_a_col_major)
+                {
+                    auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
+                        Policy::template MakeShuffledARegTileDistribution<Problem>());
+                    transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number<prefetch_idx>{}));
+                    Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
+                }
+                else
+                {
+                    Base::LocalPrefill(a_copy_lds_window,
+                                       a_block_tiles.get(number<prefetch_idx>{}),
+                                       a_element_func);
+                }
+                if constexpr(is_b_row_major)
+                {
+                    auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
+                        Policy::template MakeShuffledBRegTileDistribution<Problem>());
+                    transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number<prefetch_idx>{}));
+                    Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
+                }
+                else
+                {
+                    Base::LocalPrefill(b_copy_lds_window,
+                                       b_block_tiles.get(number<prefetch_idx>{}),
+                                       b_element_func);
+                }
             });

             block_sync_lds();
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp

@@ -519,7 +519,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
         using ALayout = remove_cvref_t<typename Problem::ALayout>;
         static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
         constexpr index_t BlockSize = Problem::kBlockSize;
-        constexpr index_t MPerBlock = Problem::BlockGemmShape::kN;
+        constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
         constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;

         constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
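This one-liner replaces BlockGemmShape::kN with kM when deriving the A-side block extent; with square tiles (kM == kN) the mistake is invisible, so it only shows up for rectangular block shapes. A tiny standalone illustration (BlockGemmShape here is a stand-in struct, not the CK type):

#include <cstdio>

struct BlockGemmShape { static constexpr int kM = 128, kN = 64, kK = 32; };

int main()
{
    constexpr int wrong_elems   = BlockGemmShape::kN * BlockGemmShape::kK; // old code path
    constexpr int correct_elems = BlockGemmShape::kM * BlockGemmShape::kK; // fixed code path
    static_assert(wrong_elems != correct_elems, "the bug is masked only when kM == kN");
    std::printf("A tile elements: wrong=%d correct=%d\n", wrong_elems, correct_elems);
}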
library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp  (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"

namespace ck {
namespace tensor_operation {
namespace host {

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename ComputeTypeA = CDataType,
          typename ComputeTypeB = ComputeTypeA>
struct ReferenceMXGemm : public device::BaseOperator
{
    // Argument
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<ADataType>& a_m_k,
                 const Tensor<ScaleDataType>& a_m_kblock_scales,
                 const Tensor<BDataType>& b_k_n,
                 const Tensor<ScaleDataType>& b_kblock_n_scales,
                 Tensor<CDataType>& c_m_n,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
            : a_m_k_{a_m_k},
              a_m_kblock_scales_{a_m_kblock_scales},
              b_k_n_{b_k_n},
              b_kblock_n_scales_{b_kblock_n_scales},
              c_m_n_{c_m_n},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op}
        {
        }

        const Tensor<ADataType>& a_m_k_;
        const Tensor<ScaleDataType>& a_m_kblock_scales_;
        const Tensor<BDataType>& b_k_n_;
        const Tensor<ScaleDataType>& b_kblock_n_scales_;
        Tensor<CDataType>& c_m_n_;

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferenceMXGemm::Argument;

        float Run(const Argument& arg)
        {
            using GemmInstance = ck::tensor_operation::host::ReferenceGemm<ComputeTypeA,
                                                                           ComputeTypeB,
                                                                           CDataType,
                                                                           AccDataType,
                                                                           AElementwiseOperation,
                                                                           BElementwiseOperation,
                                                                           CElementwiseOperation,
                                                                           ComputeTypeA,
                                                                           ComputeTypeB>;

            Tensor<ComputeTypeA> a_m_k_scaled(arg.a_m_k_.mDesc);
            Tensor<ComputeTypeB> b_k_n_scaled(arg.b_k_n_.mDesc);

            const auto M = arg.a_m_k_.mDesc.GetLengths()[0];
            const auto N = arg.b_k_n_.mDesc.GetLengths()[1];
            const auto K = arg.a_m_k_.mDesc.GetLengths()[1];

            const auto SCALE_BLOCK = K / arg.a_m_kblock_scales_.mDesc.GetLengths()[1];

            for(size_t m = 0; m < M; m++)
            {
                for(size_t k = 0; k < K; k++)
                {
                    a_m_k_scaled(m, k) =
                        type_convert<ComputeTypeA>(arg.a_m_k_(m, k)) *
                        type_convert<ComputeTypeA>(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
                }
            }

            for(size_t n = 0; n < N; n++)
            {
                for(size_t k = 0; k < K; k++)
                {
                    b_k_n_scaled(k, n) =
                        type_convert<ComputeTypeB>(arg.b_k_n_(k, n)) *
                        type_convert<ComputeTypeB>(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
                }
            }

            auto ref_gemm    = GemmInstance{};
            auto ref_invoker = ref_gemm.MakeInvoker();

            auto ref_argument = ref_gemm.MakeArgument(a_m_k_scaled,
                                                      b_k_n_scaled,
                                                      arg.c_m_n_,
                                                      arg.a_element_op_,
                                                      arg.b_element_op_,
                                                      arg.c_element_op_);

            ref_invoker.Run(ref_argument);

            return 0;
        }

        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    static auto MakeArgument(const Tensor<ADataType>& a_m_k,
                             const Tensor<ScaleDataType>& a_m_kblock_scales,
                             const Tensor<BDataType>& b_k_n,
                             const Tensor<ScaleDataType>& b_kblock_n_scales,
                             Tensor<CDataType>& c_m_n,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{a_m_k,
                        a_m_kblock_scales,
                        b_k_n,
                        b_kblock_n_scales,
                        c_m_n,
                        a_element_op,
                        b_element_op,
                        c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceMXGemm"
            << std::endl;
        // clang-format on

        return str.str();
    }
};

} // namespace host
} // namespace tensor_operation
} // namespace ck
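The new ReferenceMXGemm mirrors the other host reference operators (MakeArgument / MakeInvoker / Run): Run dequantizes both operands with their per-K-block scales and then delegates to the plain ReferenceGemm. A self-contained numeric sketch of that math, with made-up sizes and scale values rather than CK tensors:

#include <cstdio>
#include <vector>

int main()
{
    const int M = 2, N = 2, K = 4, SCALE_BLOCK = 2;
    std::vector<float> a(M * K, 1.0f), b(K * N, 1.0f), c(M * N, 0.0f);
    std::vector<float> a_scale(M * (K / SCALE_BLOCK), 2.0f); // one scale per K-block per row of A
    std::vector<float> b_scale((K / SCALE_BLOCK) * N, 0.5f); // one scale per K-block per column of B

    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.0f;
            for(int k = 0; k < K; ++k)
            {
                // Dequantize each operand with its block scale, then accumulate.
                const float av = a[m * K + k] * a_scale[m * (K / SCALE_BLOCK) + k / SCALE_BLOCK];
                const float bv = b[k * N + n] * b_scale[(k / SCALE_BLOCK) * N + n];
                acc += av * bv;
            }
            c[m * N + n] = acc;
        }

    // With all-ones data and scales 2.0 and 0.5, each output is K * (1 * 2) * (1 * 0.5) = 4.
    std::printf("c(0,0) = %.1f (expected 4.0)\n", c[0]);
}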
test/ck_tile/gemm/test_gemm_pipeline.cpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #include <tuple>

@@ -14,28 +14,26 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor;
 using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

 using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
                                              ck_tile::GemmPipelineScheduler::Intrawave>;
-using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
-                                             ck_tile::GemmPipelineScheduler::Interwave>;
+// using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
+//                                              ck_tile::GemmPipelineScheduler::Interwave>;

-using Mem  = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
+// using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
 using Comp = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Comp>;

+// TODO: Enable Memory pipeline, when it would be updated for vector loads on non-K major tensors.
 // clang-format off
 using KernelTypes = ::testing::Types<
     // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
-    std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
+    // std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
     std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Comp>,
-    std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
+    // std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
-    std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
+    // std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
     std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Comp>,
-    std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
+    // std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
-    std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
+    // std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
     std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Comp>,
-    std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
+    // std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
-    std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
+    // std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
-    std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>
+    std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>,
-    std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
+    // std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
 >;
 // clang-format on
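For context, KernelTypes feeds GoogleTest's typed tests: the suite is instantiated once per tuple, so commenting a tuple out removes that configuration from the run. A minimal sketch of the mechanism, with placeholder tuple contents and a placeholder fixture; it needs gtest to build:

#include <gtest/gtest.h>
#include <tuple>
#include <typeinfo>

template <typename Tuple>
class TestGemmPipeline : public ::testing::Test
{
};

using KernelTypes = ::testing::Types<std::tuple<int, float>, std::tuple<float, double>>;
TYPED_TEST_SUITE(TestGemmPipeline, KernelTypes);

TYPED_TEST(TestGemmPipeline, SmokeTest)
{
    // TypeParam is one std::tuple from KernelTypes; the real test unpacks it
    // with std::tuple_element_t to select layouts and data types.
    using First = std::tuple_element_t<0, TypeParam>;
    SUCCEED() << "instantiated for " << typeid(First).name();
}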