Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f1055b34
Unverified
Commit
f1055b34
authored
Feb 09, 2025
by
Muhammed Emin Ozturk
Committed by
GitHub
Feb 09, 2025
Browse files
Merge branch 'develop' into muozturk_sk_padding
parents
f84c49fa
a8c5bd9b
Changes
29
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
274 additions
and
27 deletions
+274
-27
include/ck_tile/host/check_err.hpp
include/ck_tile/host/check_err.hpp
+13
-7
include/ck_tile/host/reference/reference_gemm.hpp
include/ck_tile/host/reference/reference_gemm.hpp
+3
-2
include/ck_tile/ops/batched_transpose.hpp
include/ck_tile/ops/batched_transpose.hpp
+1
-1
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
+2
-1
include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
.../ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
+4
-1
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+5
-5
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
...tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
+68
-9
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
...gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
+0
-1
library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp
...rary/reference_tensor_operation/cpu/reference_mx_gemm.hpp
+178
-0
No files found.
include/ck_tile/host/check_err.hpp
View file @
f1055b34
...
...
@@ -22,13 +22,14 @@ template <typename ComputeDataType, typename OutDataType, typename AccDataType =
double
get_relative_threshold
(
const
int
number_of_accumulations
=
1
)
{
using
F8
=
ck_tile
::
fp8_t
;
using
BF8
=
ck_tile
::
bf8_t
;
using
F16
=
ck_tile
::
half_t
;
using
BF16
=
ck_tile
::
bf16_t
;
using
F32
=
float
;
using
I8
=
int8_t
;
using
I32
=
int32_t
;
static_assert
(
is_any_of
<
ComputeDataType
,
F8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
static_assert
(
is_any_of
<
ComputeDataType
,
F8
,
BF8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!"
);
double
compute_error
=
0
;
...
...
@@ -41,7 +42,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
compute_error
=
std
::
pow
(
2
,
-
numeric_traits
<
ComputeDataType
>::
mant
)
*
0.5
;
}
static_assert
(
is_any_of
<
OutDataType
,
F8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
static_assert
(
is_any_of
<
OutDataType
,
F8
,
BF8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
"Warning: Unhandled OutDataType for setting up the relative threshold!"
);
double
output_error
=
0
;
...
...
@@ -55,7 +56,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
}
double
midway_error
=
std
::
max
(
compute_error
,
output_error
);
static_assert
(
is_any_of
<
AccDataType
,
F8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
static_assert
(
is_any_of
<
AccDataType
,
F8
,
BF8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
"Warning: Unhandled AccDataType for setting up the relative threshold!"
);
double
acc_error
=
0
;
...
...
@@ -74,13 +75,14 @@ template <typename ComputeDataType, typename OutDataType, typename AccDataType =
double
get_absolute_threshold
(
const
double
max_possible_num
,
const
int
number_of_accumulations
=
1
)
{
using
F8
=
ck_tile
::
fp8_t
;
using
BF8
=
ck_tile
::
bf8_t
;
using
F16
=
ck_tile
::
half_t
;
using
BF16
=
ck_tile
::
bf16_t
;
using
F32
=
float
;
using
I8
=
int8_t
;
using
I32
=
int32_t
;
static_assert
(
is_any_of
<
ComputeDataType
,
F8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
static_assert
(
is_any_of
<
ComputeDataType
,
F8
,
BF8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!"
);
auto
expo
=
std
::
log2
(
std
::
abs
(
max_possible_num
));
...
...
@@ -94,7 +96,7 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
compute_error
=
std
::
pow
(
2
,
expo
-
numeric_traits
<
ComputeDataType
>::
mant
)
*
0.5
;
}
static_assert
(
is_any_of
<
OutDataType
,
F8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
static_assert
(
is_any_of
<
OutDataType
,
F8
,
BF8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
"Warning: Unhandled OutDataType for setting up the absolute threshold!"
);
double
output_error
=
0
;
...
...
@@ -108,7 +110,7 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
}
double
midway_error
=
std
::
max
(
compute_error
,
output_error
);
static_assert
(
is_any_of
<
AccDataType
,
F8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
static_assert
(
is_any_of
<
AccDataType
,
F8
,
BF8
,
F16
,
BF16
,
F32
,
I8
,
I32
,
int
>::
value
,
"Warning: Unhandled AccDataType for setting up the absolute threshold!"
);
double
acc_error
=
0
;
...
...
@@ -501,7 +503,11 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
}
if
(
!
res
)
{
std
::
cerr
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"max err: "
<<
max_err
<<
std
::
endl
;
const
float
error_percent
=
static_cast
<
float
>
(
err_count
)
/
static_cast
<
float
>
(
out
.
size
())
*
100.
f
;
std
::
cerr
<<
"max err: "
<<
max_err
;
std
::
cerr
<<
", number of errors: "
<<
err_count
;
std
::
cerr
<<
", "
<<
error_percent
<<
"% wrong values"
<<
std
::
endl
;
}
return
res
;
}
...
...
include/ck_tile/host/reference/reference_gemm.hpp
View file @
f1055b34
...
...
@@ -80,13 +80,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
int
b_index
=
(
std
::
is_same_v
<
LayoutB
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
?
col
*
strideB
+
k
:
k
*
strideB
+
col
;
acc
+=
static_cast
<
AccDataType
>
(
A
[
a_index
])
*
static_cast
<
AccDataType
>
(
B
[
b_index
]);
acc
+=
ck_tile
::
type_convert
<
AccDataType
>
(
A
[
a_index
])
*
ck_tile
::
type_convert
<
AccDataType
>
(
B
[
b_index
]);
}
int
c_index
=
(
std
::
is_same_v
<
LayoutC
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
row
*
strideC
+
col
:
col
*
strideC
+
row
;
C
[
c_index
]
=
acc
;
C
[
c_index
]
=
ck_tile
::
type_convert
<
CDataType
>
(
acc
)
;
}
}
...
...
include/ck_tile/ops/batched_transpose.hpp
View file @
f1055b34
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
View file @
f1055b34
...
...
@@ -77,6 +77,7 @@ struct CShuffleEpilogue
*
* @return The vector store size for C tensor.
*/
template
<
typename
ODataType
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetVectorSizeC
()
{
constexpr
index_t
MaxVectorStoreSize
=
16
;
...
...
@@ -142,7 +143,7 @@ struct CShuffleEpilogue
TileDistributionEncodingPattern2D
<
kBlockSize
,
kMPerIteration
,
kNPerIteration
,
GetVectorSizeC
(),
GetVectorSizeC
<
ODataType
>
(),
tile_distribution_pattern
::
thread_raked
>
;
constexpr
auto
dram_tile_distribution
=
TileEncodingPattern
::
Make2DStaticTileDistribution
();
...
...
include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
View file @
f1055b34
...
...
@@ -79,7 +79,10 @@ struct BlockUniversalGemmAsBsCr
// TODO: Should we have two policies? Interwave & Intrawave ??
static
constexpr
index_t
InterWaveSchedulingMacClusters
=
1
;
static
constexpr
index_t
KPack
=
WarpGemm
::
kKPerThread
;
// should be at least equal to: WarpGemm::Impl::kABKPerLane
// and the question is how to assess upper limit or exact value?
// TODO: Should we introduce AK1/BK1 parameters ?
static
constexpr
index_t
KPack
=
8
;
static
constexpr
index_t
KPerThread
=
KIterPerWarp
*
KPack
;
static
constexpr
index_t
KRepeat
=
KPerThread
/
KPack
;
};
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
f1055b34
...
...
@@ -159,7 +159,7 @@ struct GemmKernel
CK_TILE_HOST
static
bool
IsSupportedArgument
(
const
GemmKernelArgs
&
kargs
)
{
if
constexpr
(
EpiloguePipeline
::
GetVectorSizeC
()
%
2
!=
0
&&
if
constexpr
(
EpiloguePipeline
::
template
GetVectorSizeC
<
CDataType
>
()
%
2
!=
0
&&
is_any_of
<
CDataType
,
fp16_t
,
bf16_t
>::
value
)
{
if
(
kargs
.
k_batch
!=
1
)
...
...
@@ -240,7 +240,7 @@ struct GemmKernel
<<
std
::
endl
;
return
false
;
}
if
(
kargs
.
N
%
EpiloguePipeline
::
GetVectorSizeC
()
!=
0
)
if
(
kargs
.
N
%
EpiloguePipeline
::
template
GetVectorSizeC
<
CDataType
>
()
!=
0
)
{
std
::
cerr
<<
"N is not a multiple of vector load size for C tensor!"
<<
std
::
endl
;
return
false
;
...
...
@@ -255,7 +255,7 @@ struct GemmKernel
<<
std
::
endl
;
return
false
;
}
if
(
kargs
.
M
%
EpiloguePipeline
::
GetVectorSizeC
()
!=
0
)
if
(
kargs
.
M
%
EpiloguePipeline
::
template
GetVectorSizeC
<
CDataType
>
()
!=
0
)
{
std
::
cerr
<<
"M is not a multiple of vector load size for C tensor!"
<<
std
::
endl
;
return
false
;
...
...
@@ -321,7 +321,7 @@ struct GemmKernel
c_ptr
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
stride_C
,
1
),
number
<
EpiloguePipeline
::
GetVectorSizeC
()
>
{},
number
<
EpiloguePipeline
::
template
GetVectorSizeC
<
CDataType
>
()
>
{},
number
<
1
>
{});
}
else
...
...
@@ -519,7 +519,7 @@ struct GemmKernel
{
// Do not compile in case where we have unsupported
// VectorSizeC & data type configuration.
if
constexpr
(
!
(
EpiloguePipeline
::
GetVectorSizeC
()
%
2
!=
0
&&
if
constexpr
(
!
(
EpiloguePipeline
::
template
GetVectorSizeC
<
CDataType
>
()
%
2
!=
0
&&
is_any_of
<
CDataType
,
fp16_t
,
bf16_t
>::
value
))
{
RunGemm
<
memory_operation_enum
::
atomic_add
>
(
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
View file @
f1055b34
...
...
@@ -3,6 +3,9 @@
#pragma once
#include <string>
#include <sstream>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
...
...
@@ -83,6 +86,56 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
CK_TILE_HOST
static
std
::
string
Print
()
{
constexpr
index_t
MPerXDL
=
BlockGemm
::
WarpGemm
::
kM
;
constexpr
index_t
NPerXDL
=
BlockGemm
::
WarpGemm
::
kN
;
constexpr
index_t
KPerXDL
=
BlockGemm
::
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kK
;
constexpr
index_t
WaveSize
=
64
;
constexpr
index_t
WaveNumM
=
BlockGemmShape
::
BlockWarps
::
at
(
I0
{});
constexpr
index_t
WaveNumN
=
BlockGemmShape
::
BlockWarps
::
at
(
I1
{});
// Below should be equal to AK1|BK1
constexpr
index_t
A_LDS_Read_Width
=
Policy
::
template
GetSmemPackA
<
Problem
>();
constexpr
index_t
B_LDS_Read_Width
=
Policy
::
template
GetSmemPackB
<
Problem
>();
constexpr
index_t
A_LDS_Write_Width
=
Policy
::
template
GetSmemPackA
<
Problem
>();
constexpr
index_t
B_LDS_Write_Width
=
Policy
::
template
GetSmemPackB
<
Problem
>();
constexpr
index_t
A_Buffer_Load_Inst_Num
=
MPerBlock
*
KPerBlock
/
(
BlockSize
*
GetVectorSizeA
());
constexpr
index_t
B_Buffer_Load_Inst_Num
=
NPerBlock
*
KPerBlock
/
(
BlockSize
*
GetVectorSizeB
());
constexpr
index_t
A_LDS_Write_Inst_Num
=
MPerBlock
*
KPerBlock
/
(
BlockSize
*
A_LDS_Write_Width
);
constexpr
index_t
B_LDS_Write_Inst_Num
=
NPerBlock
*
KPerBlock
/
(
BlockSize
*
B_LDS_Write_Width
);
constexpr
index_t
A_LDS_Read_Inst_Num
=
WaveNumN
*
MPerBlock
*
KPerBlock
/
(
BlockSize
*
A_LDS_Read_Width
);
constexpr
index_t
B_LDS_Read_Inst_Num
=
WaveNumM
*
MPerBlock
*
KPerBlock
/
(
BlockSize
*
B_LDS_Read_Width
);
constexpr
index_t
C_MFMA_Inst_Num
=
MPerBlock
*
NPerBlock
*
KPerBlock
/
(
BlockSize
/
WaveSize
)
/
(
MPerXDL
*
NPerXDL
*
KPerXDL
);
auto
str
=
std
::
stringstream
{};
str
<<
"A/B vector size: "
<<
GetVectorSizeA
()
<<
", "
<<
GetVectorSizeB
()
<<
"
\n
"
<<
"A/B LDS read/write width: "
<<
A_LDS_Read_Width
<<
", "
<<
B_LDS_Read_Width
<<
"
\n
"
<<
"A/B buffer load inst: "
<<
A_Buffer_Load_Inst_Num
<<
", "
<<
B_Buffer_Load_Inst_Num
<<
"
\n
"
<<
"A/B LDS write inst: "
<<
A_LDS_Write_Inst_Num
<<
", "
<<
B_LDS_Write_Inst_Num
<<
"
\n
"
<<
"A/B LDS read inst: "
<<
A_LDS_Read_Inst_Num
<<
", "
<<
B_LDS_Read_Inst_Num
<<
"
\n
"
<<
"C MFMA inst: "
<<
C_MFMA_Inst_Num
<<
"
\n
"
<<
"KPack: "
<<
BlockGemm
::
Traits
::
KPack
<<
"
\n
"
<<
"PrefetchStages: "
<<
PrefetchStages
<<
"
\n
"
;
return
str
.
str
();
}
template
<
GemmPipelineScheduler
Scheduler
>
struct
PipelineImpl
:
public
PipelineImplBase
{
...
...
@@ -95,29 +148,35 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
CK_TILE_DEVICE
static
constexpr
auto
HotLoopScheduler
()
{
constexpr
index_t
MPerXDL
=
BlockGemm
Shape
::
Warp
Tile
::
at
(
I0
{})
;
constexpr
index_t
NPerXDL
=
BlockGemm
Shape
::
Warp
Tile
::
at
(
I1
{})
;
constexpr
index_t
KPerXDL
=
BlockGemm
Shape
::
WarpTile
::
at
(
I2
{})
;
constexpr
index_t
MPerXDL
=
BlockGemm
::
Warp
Gemm
::
kM
;
constexpr
index_t
NPerXDL
=
BlockGemm
::
Warp
Gemm
::
kN
;
constexpr
index_t
KPerXDL
=
BlockGemm
::
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kK
;
constexpr
index_t
WaveSize
=
64
;
constexpr
index_t
WaveNumM
=
BlockGemmShape
::
BlockWarps
::
at
(
I0
{});
constexpr
index_t
WaveNumN
=
BlockGemmShape
::
BlockWarps
::
at
(
I1
{});
constexpr
index_t
A_LDS_Read_Width
=
KPerXDL
;
constexpr
index_t
B_LDS_Read_Width
=
KPerXDL
;
// Below should be equal to AK1|BK1
constexpr
index_t
A_LDS_Read_Width
=
Policy
::
template
GetSmemPackA
<
Problem
>();
constexpr
index_t
B_LDS_Read_Width
=
Policy
::
template
GetSmemPackB
<
Problem
>();
constexpr
index_t
A_LDS_Write_Width
=
Policy
::
template
GetSmemPackA
<
Problem
>();
constexpr
index_t
B_LDS_Write_Width
=
Policy
::
template
GetSmemPackB
<
Problem
>();
constexpr
index_t
A_Buffer_Load_Inst_Num
=
MPerBlock
*
KPerBlock
/
(
BlockSize
*
GetVectorSizeA
());
constexpr
index_t
B_Buffer_Load_Inst_Num
=
NPerBlock
*
KPerBlock
/
(
BlockSize
*
GetVectorSizeB
());
constexpr
index_t
A_LDS_Write_Inst_Num
=
MPerBlock
*
KPerBlock
/
(
BlockSize
*
KPerXDL
);
constexpr
index_t
B_LDS_Write_Inst_Num
=
NPerBlock
*
KPerBlock
/
(
BlockSize
*
KPerXDL
);
constexpr
index_t
A_LDS_Write_Inst_Num
=
MPerBlock
*
KPerBlock
/
(
BlockSize
*
A_LDS_Write_Width
);
constexpr
index_t
B_LDS_Write_Inst_Num
=
NPerBlock
*
KPerBlock
/
(
BlockSize
*
B_LDS_Write_Width
);
constexpr
index_t
A_LDS_Read_Inst_Num
=
WaveNumN
*
MPerBlock
*
KPerBlock
/
(
BlockSize
*
KPerXDL
);
WaveNumN
*
MPerBlock
*
KPerBlock
/
(
BlockSize
*
A_LDS_Read_Width
);
constexpr
index_t
B_LDS_Read_Inst_Num
=
WaveNumM
*
MPerBlock
*
KPerBlock
/
(
BlockSize
*
KPerXDL
);
WaveNumM
*
MPerBlock
*
KPerBlock
/
(
BlockSize
*
B_LDS_Read_Width
);
constexpr
index_t
C_MFMA_Inst_Num
=
MPerBlock
*
NPerBlock
*
KPerBlock
/
(
BlockSize
/
WaveSize
)
/
...
...
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
View file @
f1055b34
...
...
@@ -185,7 +185,6 @@ struct UniversalGemmPipelineAgBgCrPolicy
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp
0 → 100644
View file @
f1055b34
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
host
{
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
ComputeTypeA
=
CDataType
,
typename
ComputeTypeB
=
ComputeTypeA
>
struct
ReferenceMXGemm
:
public
device
::
BaseOperator
{
// Argument
struct
Argument
:
public
device
::
BaseArgument
{
Argument
(
const
Tensor
<
ADataType
>&
a_m_k
,
const
Tensor
<
ScaleDataType
>&
a_m_kblock_scales
,
const
Tensor
<
BDataType
>&
b_k_n
,
const
Tensor
<
ScaleDataType
>&
b_kblock_n_scales
,
Tensor
<
CDataType
>&
c_m_n
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
:
a_m_k_
{
a_m_k
},
a_m_kblock_scales_
{
a_m_kblock_scales
},
b_k_n_
{
b_k_n
},
b_kblock_n_scales_
{
b_kblock_n_scales
},
c_m_n_
{
c_m_n
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
}
{
}
const
Tensor
<
ADataType
>&
a_m_k_
;
const
Tensor
<
ScaleDataType
>&
a_m_kblock_scales_
;
const
Tensor
<
BDataType
>&
b_k_n_
;
const
Tensor
<
ScaleDataType
>&
b_kblock_n_scales_
;
Tensor
<
CDataType
>&
c_m_n_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
};
// Invoker
struct
Invoker
:
public
device
::
BaseInvoker
{
using
Argument
=
ReferenceMXGemm
::
Argument
;
float
Run
(
const
Argument
&
arg
)
{
using
GemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ComputeTypeA
,
ComputeTypeB
,
CDataType
,
AccDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
ComputeTypeA
,
ComputeTypeB
>
;
Tensor
<
ComputeTypeA
>
a_m_k_scaled
(
arg
.
a_m_k_
.
mDesc
);
Tensor
<
ComputeTypeB
>
b_k_n_scaled
(
arg
.
b_k_n_
.
mDesc
);
const
auto
M
=
arg
.
a_m_k_
.
mDesc
.
GetLengths
()[
0
];
const
auto
N
=
arg
.
b_k_n_
.
mDesc
.
GetLengths
()[
1
];
const
auto
K
=
arg
.
a_m_k_
.
mDesc
.
GetLengths
()[
1
];
const
auto
SCALE_BLOCK
=
K
/
arg
.
a_m_kblock_scales_
.
mDesc
.
GetLengths
()[
1
];
for
(
size_t
m
=
0
;
m
<
M
;
m
++
)
{
for
(
size_t
k
=
0
;
k
<
K
;
k
++
)
{
a_m_k_scaled
(
m
,
k
)
=
type_convert
<
ComputeTypeA
>
(
arg
.
a_m_k_
(
m
,
k
))
*
type_convert
<
ComputeTypeA
>
(
arg
.
a_m_kblock_scales_
(
m
,
k
/
SCALE_BLOCK
));
}
}
for
(
size_t
n
=
0
;
n
<
N
;
n
++
)
{
for
(
size_t
k
=
0
;
k
<
K
;
k
++
)
{
b_k_n_scaled
(
k
,
n
)
=
type_convert
<
ComputeTypeB
>
(
arg
.
b_k_n_
(
k
,
n
))
*
type_convert
<
ComputeTypeB
>
(
arg
.
b_kblock_n_scales_
(
k
/
SCALE_BLOCK
,
n
));
}
}
auto
ref_gemm
=
GemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_m_k_scaled
,
b_k_n_scaled
,
arg
.
c_m_n_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
);
ref_invoker
.
Run
(
ref_argument
);
return
0
;
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/* stream_config */
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
bool
IsSupportedArgument
(
const
device
::
BaseArgument
*
)
override
{
return
true
;
}
static
auto
MakeArgument
(
const
Tensor
<
ADataType
>&
a_m_k
,
const
Tensor
<
ScaleDataType
>&
a_m_kblock_scales
,
const
Tensor
<
BDataType
>&
b_k_n
,
const
Tensor
<
ScaleDataType
>&
b_kblock_n_scales
,
Tensor
<
CDataType
>&
c_m_n
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
{
return
Argument
{
a_m_k
,
a_m_kblock_scales
,
b_k_n
,
b_kblock_n_scales
,
c_m_n
,
a_element_op
,
b_element_op
,
c_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
virtual
std
::
unique_ptr
<
device
::
BaseInvoker
>
MakeInvokerPointer
()
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"ReferenceMXGemm"
<<
std
::
endl
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace host
}
// namespace tensor_operation
}
// namespace ck
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment