gaoqiong / composable_kernel_ROCM · Commits · 67ab3896
".github/workflows/release-docker-deepep.yml" did not exist on "18317ddc13bc403749fe9f99ef5726796f855b0e"
Commit 67ab3896, authored Jan 08, 2025 by Aleksander Dudek

Merge branch 'develop' into gemm_getname

Parents: 8adaf418, d5c8a334
Showing 20 of 100 changed files, with 1463 additions and 204 deletions (+1463, -204).
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp (+45, -14)
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp (+17, -0)
include/ck_tile/ops/norm_reduce.hpp (+3, -4)
include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp (+78, -48)
include/ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp (+7, -2)
include/ck_tile/ops/norm_reduce/thread/thread_welford.hpp (+0, -0)
include/ck_tile/ref/naive_attention.hpp (+291, -131)
library/include/ck/library/tensor_operation_instance/gpu/gemm_b_scale.hpp (+91, -0)
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp (+500, -0)
library/src/tensor_operation_instance/gpu/CMakeLists.txt (+4, -4)
library/src/tensor_operation_instance/gpu/gemm_b_scale/CMakeLists.txt (+10, -0)
library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp (+105, -0)
library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp (+32, -0)
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/CMakeLists.txt (+38, -1)
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn.hpp (+91, -0)
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_default_instance.cpp (+30, -0)
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instance.cpp (+30, -0)
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instance.cpp (+30, -0)
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instance.cpp (+30, -0)
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instance.cpp (+31, -0)
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
@@ -17,6 +17,7 @@ struct Layernorm2dFwdPipelineTwoPass
     using Policy = ck_tile::remove_cvref_t<Policy_>;

     using XDataType       = ck_tile::remove_cvref_t<typename Problem::XDataType>;
+    using XBiasDataType   = ck_tile::remove_cvref_t<typename Problem::XBiasDataType>;
     using GammaDataType   = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
     using BetaDataType    = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
     using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;

@@ -36,6 +37,8 @@ struct Layernorm2dFwdPipelineTwoPass
     static constexpr bool kPadM     = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
     static constexpr bool kPadN     = Problem::Traits::kPadN;
     static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
+    static constexpr bool kWelford  = Problem::Traits::kWelford;
+    static constexpr auto kXbias    = Problem::Traits::kXbias;
     static constexpr auto kFusedAdd   = Problem::Traits::kFusedAdd;
     static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;

@@ -53,6 +56,7 @@ struct Layernorm2dFwdPipelineTwoPass
     template <typename XWindow,
               typename XResidualWindow,
+              typename XBiasWindow,
               typename GammaWindow,
               typename BetaWindow,
               typename YWindow,

@@ -64,6 +68,7 @@ struct Layernorm2dFwdPipelineTwoPass
               typename Epilogue>
     CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
                                    const XResidualWindow& x_residual_window_,
+                                   const XBiasWindow& x_bias_window_,
                                    const GammaWindow& gamma_window_,
                                    const BetaWindow& beta_window_,
                                    YWindow& y_window,

@@ -77,8 +82,11 @@ struct Layernorm2dFwdPipelineTwoPass
                                    void* smem,
                                    Epilogue) const
     {
+        static_assert(kWelford == true, "2 pass only supports welford merge");
         auto x_window = make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
+        auto x_bias_window =
+            make_tile_window(x_bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
         auto gamma_window =
             make_tile_window(gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
         auto beta_window = make_tile_window(

@@ -102,24 +110,35 @@ struct Layernorm2dFwdPipelineTwoPass
         int max_count =
             (num_n_tile_iteration - 1) * count_per_iter +
             block_tile_welford_calculate_max_count<typename Problem::BlockShape>(last_iter_n);
-        auto block_welford                 = Policy::template GetBlockWelford<Problem>();
-        auto block_welford_sync            = Policy::template GetBlockWelfordSync<Problem>();
-        auto block_welford_cross_warp_sync = Policy::template GetBlockWelfordCrossWarpSync<Problem>();
+        auto block_norm_reduce                 = Policy::template GetBlockNormReduce<Problem>();
+        auto block_norm_reduce_sync            = Policy::template GetBlockNormReduceSync<Problem>();
+        auto block_norm_reduce_cross_warp_sync = Policy::template GetBlockNormReduceCrossWarpSync<Problem>();

         using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
-        auto mean = block_welford.template MakeMeanVarBlockTile<XTensorType>();
-        auto var  = block_welford.template MakeMeanVarBlockTile<XTensorType>();
+        auto mean = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
+        auto var  = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();

         for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
         {
-            auto x      = load_tile(x_window);
-            auto x_resi = load_tile(x_residual_window);
+            auto x            = load_tile(x_window);
+            auto x_resi       = load_tile(x_residual_window);
+            const auto x_bias = load_tile(x_bias_window);
             move_tile_window(x_window, {0, Block_N});
             move_tile_window(x_residual_window, {0, Block_N});
+            move_tile_window(x_bias_window, {Block_N});

             auto acc = cast_tile<ComputeDataType>(x);
+            if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
+            {
+                sweep_tile(x, [&](auto idx) {
+                    // compute x = bias + x
+                    constexpr auto j_idx = make_tuple(idx[number<1>{}]);
+                    acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
+                });
+            }
             if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
                          kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
             {

@@ -133,11 +152,11 @@ struct Layernorm2dFwdPipelineTwoPass
                     move_tile_window(y_residual_window, {0, Block_N});
                 }
             }
-            block_welford(acc, mean, var, cur_count, max_count);
+            block_norm_reduce(acc, mean, var, cur_count, max_count);
         }
-        block_welford_sync(mean, var, cur_count);
-        block_welford_cross_warp_sync(mean, var, cur_count, smem);
+        block_norm_reduce_sync(mean, var, cur_count);
+        block_norm_reduce_cross_warp_sync(mean, var, cur_count, smem);
         block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});

         // compute inv-std

@@ -165,6 +184,7 @@ struct Layernorm2dFwdPipelineTwoPass
         move_tile_window(x_window, {0, -Block_N});
         move_tile_window(x_residual_window, {0, -Block_N});
+        move_tile_window(x_bias_window, {-Block_N});
         move_tile_window(gamma_window, {stride_to_right_most_window});
         move_tile_window(beta_window, {stride_to_right_most_window});
         move_tile_window(y_window, {0, stride_to_right_most_window});

@@ -172,9 +192,19 @@ struct Layernorm2dFwdPipelineTwoPass
         // layernorm computation
         for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
         {
-            auto x      = load_tile(x_window);
-            auto x_resi = load_tile(x_residual_window);
-            auto acc    = cast_tile<ComputeDataType>(x);
+            auto x            = load_tile(x_window);
+            auto x_resi       = load_tile(x_residual_window);
+            const auto x_bias = load_tile(x_bias_window);
+            auto acc          = cast_tile<ComputeDataType>(x);
+            if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
+            {
+                sweep_tile(x, [&](auto idx) {
+                    // compute x = bias + x
+                    constexpr auto j_idx = make_tuple(idx[number<1>{}]);
+                    acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
+                });
+            }
             if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
                          kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)

@@ -207,6 +237,7 @@ struct Layernorm2dFwdPipelineTwoPass
         move_tile_window(x_window, {0, -Block_N});
         move_tile_window(x_residual_window, {0, -Block_N});
+        move_tile_window(x_bias_window, {-Block_N});
         move_tile_window(gamma_window, {-Block_N});
         move_tile_window(beta_window, {-Block_N});
         move_tile_window(y_window, {0, -Block_N});
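The change above threads an optional per-column x-bias through both passes of the pipeline: the bias is added to x before the mean/variance statistics in pass one and again before normalization in pass two. As a reading aid only, here is a hedged host-side sketch of the math one row of this fused layernorm computes; it is plain C++, not the tile-window API, and `eps`, the loop structure, and the function name are illustrative assumptions.

```cpp
#include <cmath>
#include <vector>

// Illustrative reference for one row of layernorm2d with an optional x-bias
// (assumption: bias is indexed by column, matching the j_idx use above).
std::vector<float> layernorm2d_row(const std::vector<float>& x,
                                   const std::vector<float>& bias,
                                   const std::vector<float>& gamma,
                                   const std::vector<float>& beta,
                                   float eps = 1e-5f)
{
    const size_t n = x.size();

    // pass 1: accumulate mean / variance of (x + bias)
    float mean = 0.f, var = 0.f;
    for(size_t j = 0; j < n; ++j)
    {
        float v = x[j] + bias[j];
        mean += v;
        var += v * v;
    }
    mean /= n;
    var = var / n - mean * mean;
    float inv_std = 1.f / std::sqrt(var + eps);

    // pass 2: normalize, then apply gamma/beta
    std::vector<float> y(n);
    for(size_t j = 0; j < n; ++j)
        y[j] = ((x[j] + bias[j]) - mean) * inv_std * gamma[j] + beta[j];
    return y;
}
```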
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp
@@ -7,6 +7,19 @@
 namespace ck_tile {

+enum class Layernorm2dXBiasEnum
+{
+    NO_BIAS = 0,
+    // add bias before fused add
+    ADD_BIAS = 1,
+};
+
+// clang-format off
+template<Layernorm2dXBiasEnum> struct Layernorm2dXBiasEnumName;
+template<> struct Layernorm2dXBiasEnumName<Layernorm2dXBiasEnum::NO_BIAS> { static constexpr const char * name = "no"; };
+template<> struct Layernorm2dXBiasEnumName<Layernorm2dXBiasEnum::ADD_BIAS> { static constexpr const char * name = "xbias"; };
+// clang-format on
+
 enum class Layernorm2dFusedAddEnum
 {
     NO_ADD = 0,

@@ -40,7 +53,9 @@ template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::SMOOT
 template <bool kPadN_,
           bool kSaveMeanInvStd_,
           bool kFastFDiv_,
+          bool kWelford_,
           bool kTwoPass_,
+          Layernorm2dXBiasEnum kXbias_,
           Layernorm2dFusedAddEnum kFusedAdd_,
           Layernorm2dFusedQuantEnum kFusedQuant_>
 struct Layernorm2dFwdTraits

@@ -48,7 +63,9 @@ struct Layernorm2dFwdTraits
     static constexpr bool kPadN           = kPadN_;
     static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
     static constexpr bool kFastFDiv       = kFastFDiv_;
+    static constexpr bool kWelford        = kWelford_;
     static constexpr bool kTwoPass        = kTwoPass_;
+    static constexpr Layernorm2dXBiasEnum kXbias            = kXbias_;
     static constexpr Layernorm2dFusedAddEnum kFusedAdd      = kFusedAdd_;
     static constexpr Layernorm2dFusedQuantEnum kFusedQuant  = kFusedQuant_;
 };
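For orientation, a minimal sketch of how the expanded trait pack might be instantiated. The bool values, the alias name `LnTraits`, and the cast used to avoid naming a fused-quant enumerator are illustrative assumptions, not taken from the diff.

```cpp
// Illustrative instantiation of the expanded trait pack; only the parameter
// order mirrors the template above, the concrete values are arbitrary.
using LnTraits = ck_tile::Layernorm2dFwdTraits<
    /* kPadN_           */ true,
    /* kSaveMeanInvStd_ */ false,
    /* kFastFDiv_       */ true,
    /* kWelford_        */ true, // the two-pass pipeline static_asserts on this
    /* kTwoPass_        */ true,
    ck_tile::Layernorm2dXBiasEnum::ADD_BIAS,
    ck_tile::Layernorm2dFusedAddEnum::NO_ADD,
    // cast to avoid assuming an enumerator name for "no fused quant"
    static_cast<ck_tile::Layernorm2dFusedQuantEnum>(0)>;
```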
include/ck_tile/ops/welford.hpp → include/ck_tile/ops/norm_reduce.hpp

@@ -3,9 +3,8 @@
 #pragma once

-#include "ck_tile/ops/welford/block/block_welford.hpp"
-#include "ck_tile/ops/welford/block/block_welford_problem.hpp"
-#include "ck_tile/ops/welford/thread/thread_welford.hpp"
+#include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp"
+#include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp"
+#include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
 #include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/welford/block/block_welford.hpp → include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp

@@ -4,22 +4,23 @@
 #pragma once

 #include "ck_tile/core.hpp"
-#include "ck_tile/ops/welford/thread/thread_welford.hpp"
+#include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp"

 namespace ck_tile {

 template <typename Problem_, typename Policy_ = void>
-struct BlockWelford
+struct BlockNormReduce
 {
     using Problem         = remove_cvref_t<Problem_>;
     using XDataType       = typename Problem::XDataType;
     using ComputeDataType = typename Problem::ComputeDataType;
     static constexpr bool kFastFDiv = Problem::kFastFDiv;
+    static constexpr bool kWelford  = Problem::kWelford;

-    CK_TILE_DEVICE constexpr BlockWelford() {}
+    CK_TILE_DEVICE constexpr BlockNormReduce() {}

     // [CAUSION] - max_count_ is to deal with the padding problem
-    // max_count_ is depend on caller, eg: naive and splitN welford will have different
+    // max_count_ is depend on caller, eg: naive and splitN norm_reduce will have different
     // calculation of max_count_
     // -> use block_welford_calculate_max_count to compute
     template <typename XDistributedTensor_,

@@ -40,18 +41,24 @@ struct BlockWelford
         if(cur_count_ < max_count_)
         {
             ++cur_count_;
             sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
                 constexpr auto in_dstr_idx  = make_tuple(dstr_idx_i0, dstr_idx_i1);
                 constexpr auto out_dstr_idx = make_tuple(dstr_idx_i0);

                 auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);

-                welford_update(mean_tensor(out_dstr_idx), var_tensor(out_dstr_idx), x, cur_count_, constant<kFastFDiv>{});
+                if(kWelford)
+                {
+                    welford_update(mean_tensor(out_dstr_idx), var_tensor(out_dstr_idx), x, cur_count_, constant<kFastFDiv>{});
+                }
+                else
+                {
+                    mean_tensor(out_dstr_idx) += x;
+                    var_tensor(out_dstr_idx) += x * x;
+                }
             });
         }
     });

@@ -91,10 +98,11 @@ struct BlockWelford
 };

 template <typename Problem_, typename Policy_ = void>
-struct BlockWelfordSync
+struct BlockNormReduceSync
 {
     using Problem = remove_cvref_t<Problem_>;
     static constexpr bool kFastFDiv = Problem::kFastFDiv;
+    static constexpr bool kWelford  = Problem::kWelford;

     template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
     CK_TILE_DEVICE void

@@ -152,36 +160,48 @@ struct BlockWelfordSync
                     (number<lid_over_rid_derivative << istage.value>{}.value);

                 // pull data from remote lane
-                const auto v_remote_mean  = warp_shuffle(v_local_mean, src_lane);
-                const auto v_remote_var   = warp_shuffle(v_local_var, src_lane);
-                const auto v_remote_count = warp_shuffle(v_local_count, src_lane);
-
-                // welford merge
-                welford_merge(v_local_mean, v_local_var, v_local_count,
-                              v_remote_mean, v_remote_var, v_remote_count, constant<kFastFDiv>{});
+                const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane);
+                const auto v_remote_var  = warp_shuffle(v_local_var, src_lane);
+                if(kWelford)
+                {
+                    const auto v_remote_count = warp_shuffle(v_local_count, src_lane);
+                    // norm_reduce merge
+                    welford_merge(v_local_mean, v_local_var, v_local_count,
+                                  v_remote_mean, v_remote_var, v_remote_count, constant<kFastFDiv>{});
+                }
+                else
+                {
+                    v_local_mean += v_remote_mean;
+                    v_local_var += v_remote_var;
+                }
             });
         }
     });

     mean_tensor.get_thread_buffer()(i) = v_local_mean;
     var_tensor.get_thread_buffer()(i)  = v_local_var;
-    count = v_local_count;
+    if(kWelford)
+    {
+        count = v_local_count;
+    }
     });
 }
 };

 template <typename Problem_, typename Policy_ = void>
-struct BlockWelfordCrossWarpSync
+struct BlockNormReduceCrossWarpSync
 {
     using Problem    = remove_cvref_t<Problem_>;
     using BlockShape = typename Problem::BlockShape;
     static constexpr bool kFastFDiv = Problem::kFastFDiv;
+    static constexpr bool kWelford  = Problem::kWelford;
+    using smem_dtype = std::conditional_t<kWelford, fp32x4_t, fp32x2_t>;

     template <typename MeanDistributedTensor_>
     CK_TILE_DEVICE static constexpr index_t GetReduceWarps()

@@ -252,7 +272,7 @@ struct BlockWelfordCrossWarpSync
         static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());

         // Note: we always pack everything into fp32x4
-        fp32x4_t* smem_ptr = reinterpret_cast<fp32x4_t*>(smem);
+        smem_dtype* smem_ptr = reinterpret_cast<smem_dtype*>(smem);
         const index_t lane_id = get_lane_id();
         const index_t warp_id = get_warp_id();
         constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();

@@ -267,11 +287,13 @@ struct BlockWelfordCrossWarpSync
         if(lane_id == 0)
         {
             static_for<0, thread_buf_size, 1>{}([&](auto i) {
-                fp32x4_t local_scratch_;
+                smem_dtype local_scratch_;
                 local_scratch_[0] = bit_cast<float>(mean_tensor.get_thread_buffer()[i]);
                 local_scratch_[1] = bit_cast<float>(var_tensor.get_thread_buffer()[i]);
-                local_scratch_[2] = bit_cast<float>(count);
+                if(kWelford)
+                {
+                    local_scratch_[2] = bit_cast<float>(count);
+                }

                 smem_ptr[smem_offset + i * num_warps] = local_scratch_;
             });
         }

@@ -280,7 +302,7 @@ struct BlockWelfordCrossWarpSync
         // load from smem. here we let everythread to do compute :)
         index_t local_warp_id = warp_id / num_reduce_warps;
         index_t local_smem_os = local_warp_id * num_reduce_warps;
-        fp32x4_t all_scratch[thread_buf_size * num_reduce_warps];
+        smem_dtype all_scratch[thread_buf_size * num_reduce_warps];
         static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
             static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
                 all_scratch[i_0 * num_reduce_warps + i_1] =

@@ -293,32 +315,40 @@ struct BlockWelfordCrossWarpSync
         static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
             // TODO: use descriptor for this
-            auto v_local       = all_scratch[i_0 * num_reduce_warps];
-            auto v_local_mean  = bit_cast<DataType>(v_local[0]);
-            auto v_local_var   = bit_cast<DataType>(v_local[1]);
-            auto v_local_count = bit_cast<int>(v_local[2]);
+            auto v_local      = all_scratch[i_0 * num_reduce_warps];
+            auto v_local_mean = bit_cast<DataType>(v_local[0]);
+            auto v_local_var  = bit_cast<DataType>(v_local[1]);
+            int v_local_count = kWelford ? bit_cast<int>(v_local[2]) : 0;

             // further reduce mean/var
             static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
                 constexpr auto i_1 = number<i_1_n1 + 1>{};
-                const fp32x4_t v_remote   = all_scratch[i_0 * num_reduce_warps + i_1];
+                const smem_dtype v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
                 const auto v_remote_mean  = bit_cast<DataType>(v_remote[0]);
                 const auto v_remote_var   = bit_cast<DataType>(v_remote[1]);
-                const auto v_remote_count = bit_cast<int>(v_remote[2]);
-
-                welford_merge(v_local_mean, v_local_var, v_local_count,
-                              v_remote_mean, v_remote_var, v_remote_count, constant<kFastFDiv>{});
+                if(kWelford)
+                {
+                    const auto v_remote_count = bit_cast<int>(v_remote[2]);
+                    welford_merge(v_local_mean, v_local_var, v_local_count,
+                                  v_remote_mean, v_remote_var, v_remote_count, constant<kFastFDiv>{});
+                }
+                else
+                {
+                    v_local_mean += v_remote_mean;
+                    v_local_var += v_remote_var;
+                }
             });

             mean_tensor.get_thread_buffer()(i_0) = v_local_mean;
             var_tensor.get_thread_buffer()(i_0)  = v_local_var;
-            count = v_local_count;
+            if(kWelford)
+                count = v_local_count;
         });
     }
 };
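The new kWelford switch selects between a numerically stable Welford update (which carries an element count that must be merged across lanes and warps) and a cheaper sum / sum-of-squares reduction whose variance is finalized as E[x^2] - E[x]^2 and which merges by simple addition. A standalone sketch of the two strategies follows; it is plain C++ for illustration, and the function names are not from the library.

```cpp
#include <vector>

// Welford-style running update: stable, needs a count alongside mean/var.
void welford_update_ref(float x, float& mean, float& m2, int& count)
{
    ++count;
    float delta = x - mean;
    mean += delta / count;
    m2 += delta * (x - mean); // m2 accumulates the sum of squared deviations
}

// Plain-moment update: cheaper, merges by simple addition, finalized later.
void moment_update_ref(float x, float& sum, float& sum_sq)
{
    sum += x;
    sum_sq += x * x;
}

// Both paths recover the same variance for a row of length n:
//   Welford:  var = m2 / n
//   Moments:  var = sum_sq / n - (sum / n)^2
float variance_from_moments(float sum, float sum_sq, int n)
{
    float mean = sum / n;
    return sum_sq / n - mean * mean;
}
```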
include/ck_tile/ops/welford/block/block_welford_problem.hpp → include/ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp

@@ -7,13 +7,18 @@
 namespace ck_tile {

-template <typename XDataType_, typename ComputeDataType_, typename BlockShape_, bool kFastFDiv_>
-struct BlockWelfordProblem
+template <typename XDataType_,
+          typename ComputeDataType_,
+          typename BlockShape_,
+          bool kFastFDiv_,
+          bool kWelford_>
+struct BlockNormReduceProblem
 {
     using XDataType       = remove_cvref_t<XDataType_>;
     using ComputeDataType = remove_cvref_t<ComputeDataType_>;
     using BlockShape      = remove_cvref_t<BlockShape_>;
     static constexpr bool kFastFDiv = kFastFDiv_;
+    static constexpr bool kWelford  = kWelford_;
 };

 } // namespace ck_tile
include/ck_tile/ops/welford/thread/thread_welford.hpp → include/ck_tile/ops/norm_reduce/thread/thread_welford.hpp

File moved (no content changes).
include/ck_tile/ref/naive_attention.hpp
@@ -13,13 +13,18 @@ namespace ck_tile {
 enum class naive_attention_layout_enum
 {
+    DEFAULT, // maybe this tensor is not used, set some irrelevant value
     BSHD,  // [batch, seqlen, nhead, hdim]
     BHSD,  // [batch, nhead, seqlen, hdim]
     BS3HD, // [batch, nhead, 3, seqlen, hdim], used when qkv are packed
     PHSD,  // [pages, nhead, page_size, hdim]
+    // PHSDX, // [pages, nhead, page_size/x, hdim, x], where <# used pages>*page_size = seqlen
     PHDSX, // [pages, nhead, hdim/x, page_size, x], where <# used pages>*page_size = seqlen
     PHDS,  // [pages, nhead, hdim, page_size], where <# used pages>*page_size = seqlen
+    // scale layout used for dynamic dequant
+    SCALE_HS, // [nhead, tokens] or [nhead, tokens-per-group], the KVCache quant
+    SCALE_SH, // [tokens, nhead]
 };

 // will used to specialize kernel variation

@@ -30,6 +35,15 @@ enum class naive_attention_variation_enum
     DECODE_PAGED, // decode attn, where kv token from another buffer called kvcache
 };

+enum class naive_attention_quant_algo
+{
+    NO               = 0,
+    KV_8BIT_PERHEAD  = 1,
+    // FP8/INT8 quant for KVCache, per-token quant
+    // [num_tokens, nhead, hdim] -> [nhead, num_tokens]
+    KV_8BIT_PERTOKEN = 2,
+};
+
 // TODO: for simplicity, this will be used as host/device arg
 struct naive_attention_fwd_args
 {
@@ -40,7 +54,8 @@ struct naive_attention_fwd_args
     void* context_len_ptr; // [batch] used when seqlen kv come from a pointer(each element is a
                            // number, not cumsum)
     void* page_table_ptr;  // [batch, max_pages_per_seq] seqlen_kv is in different block(paged attn)
-    void* kvscale_ptr;     // [nhead, 2(kv), hdim] used for kvcache dequant
+    void* kscale_ptr;      // [nhead, max_kv_tokens] used for kvcache dequant
+    void* vscale_ptr;      // [nhead, max_kv_tokens] used for kvcache dequant
     float scale_s;
     int hdim;
     int hdim_v; // could be cross-attn, where V and Q/K hdim are different

@@ -54,6 +69,7 @@ struct naive_attention_fwd_args
     int nhead_ratio_kv; // nhead_q / nhead_kv
     int page_size;      // if paged, the seqlen-kv per each block
     int max_pages_per_seq;
+    int max_kv_tokens; // used as stride to access kv scale ptr
 };

 // this is trait for host API

@@ -67,14 +83,16 @@ struct naive_attention_fwd_traits
     std::string k_layout;
     std::string v_layout;
     std::string o_layout;
-    int variation; // sync with naive_attention_variation_enum
+    int variation;  // sync with naive_attention_variation_enum
+    int quant_algo; // sync with naive_attention_quant_algo
 };

 // this is trait for kernel template
-template <naive_attention_variation_enum variation_>
+template <naive_attention_variation_enum variation_, naive_attention_quant_algo quant_algo_>
 struct naive_attention_fwd_kernel_traits
 {
     static constexpr naive_attention_variation_enum variation = variation_;
+    static constexpr naive_attention_quant_algo quant_algo    = quant_algo_;
 };

 // for simplicity, please do not use const-reference type for the template type
@@ -83,28 +101,39 @@ template <typename QType,
           typename VType,
           typename OType,
           typename AccType,
+          typename KVScaleType,
           naive_attention_layout_enum QLayout,
           naive_attention_layout_enum KLayout,
           naive_attention_layout_enum VLayout,
           naive_attention_layout_enum OLayout,
+          naive_attention_layout_enum KScaleLayout,
+          naive_attention_layout_enum VScaleLayout,
           typename Traits>
 struct naive_attention_fwd_kernel
 {
     static constexpr bool is_kvcache_i8 =
-        std::is_same_v<KType, int8_t> && std::is_same_v<VType, int8_t> && sizeof(QType) != 1;
+        std::is_same_v<KType, int8_t> && std::is_same_v<VType, int8_t>;
+    static constexpr bool is_kvcache_fp8 =
+        std::is_same_v<KType, fp8_t> && std::is_same_v<VType, fp8_t>;
     // kvcache-i8 will have per head scale, we apply this scale to Q/P matrix instead of original
     // K/V matrix. This can speed up conversion since Q/P usually is fp16/bf16/fp32
     static constexpr bool is_kvcache_i8_forward_quant = is_kvcache_i8;
+    static constexpr int v_per_token_quant_group_size = 64; // TODO: hardcode

-    using KVScaleType = float;
-    using SoftmaxType = float;
-    using PType       = VType; // src A of gemm2, same type as V
+    using SoftmaxType      = float; // always using float to do softmax compute
+    using QuantComputeType = float; // used for quant/dequant scale compute
+    using QCompute         = KType; // src A of gemm1, same type as K
+    using PType            = VType; // src A of gemm2, same type as V
+    using OAccType         = float; // always float, in case int8 FA

     using p_vec_type = ext_vector_t<PType, 16 / sizeof(PType)>;
     static constexpr int p_vec_elem = vector_traits<p_vec_type>::vector_size;

+    // clang-format off
+    template<typename T_> struct scale_max { static constexpr float value = 1; /* dummy code */ };
+    template<> struct scale_max<int8_t> { static constexpr float value = 127.0; };
+    template<> struct scale_max<fp8_t> { static constexpr float value = 240.0; };
+    // clang-format on
+
     __host__ __device__ naive_attention_fwd_kernel() {}

     template <typename T, naive_attention_layout_enum Layout>
@@ -198,24 +227,31 @@ struct naive_attention_fwd_kernel
         __device__ void store(T /*value*/, int /*i_s*/, int /*i_d*/) {}
     };

-    template <typename T>
+    template <typename T, naive_attention_layout_enum Layout>
     struct kvscale_addresser
     {
-        int h, d; // nhead, hdim
+        int s, h, d; // seqlen(tokens), nhead, hdim
         T* base_ptr;
-        __device__ kvscale_addresser(int h_, int d_, void* p_)
-            : h(h_), d(d_), base_ptr(reinterpret_cast<T*>(p_))
+        __device__ kvscale_addresser(int s_, int h_, int d_, void* p_)
+            : s(s_), h(h_), d(d_), base_ptr(reinterpret_cast<T*>(p_))
         {
         }
-        __device__ int get_offset(int i_h, int i_d, int i_kv /*0 or 1*/)
+        __device__ int get_offset(int i_s, int i_h, int i_d)
         {
+            if constexpr(Layout == naive_attention_layout_enum::SCALE_HS)
+            {
+                // [nhead, tokens]
+                (void)i_d;
+                return i_h * s + i_s;
+            }
+            else if constexpr(Layout == naive_attention_layout_enum::DEFAULT)
+            {
+                return 0;
+            }
-            // [h, 2, d]
-            return i_h * 2 * d + i_kv * d + i_d;
         }
-        __device__ T load(int i_h, int i_d, int i_kv)
-        {
-            return base_ptr[get_offset(i_h, i_d, i_kv)];
-            // return i_h * 2 * d + i_kv * d + i_d;
-        }
+        __device__ T load(int i_s, int i_h, int i_d) { return base_ptr[get_offset(i_s, i_h, i_d)]; }
     };

     __device__ __host__ static constexpr int get_block_size() { return 256; }
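For the new SCALE_HS layout the per-token scale tensor is laid out as [nhead, tokens], so get_offset reduces to a plain row-major index. A tiny hedged illustration (the numbers are made up):

```cpp
// Illustrative index math for SCALE_HS ([nhead, tokens]) per-token scales.
// With tokens = 4096, head i_h = 2 and token i_s = 10, the scale lives at
//   offset = i_h * tokens + i_s = 2 * 4096 + 10 = 8202.
int scale_hs_offset(int i_s, int i_h, int tokens) { return i_h * tokens + i_s; }
```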
@@ -282,12 +318,13 @@ struct naive_attention_fwd_kernel
     __device__ void operator()(naive_attention_fwd_args args)
     {
         constexpr int wg_size = get_block_size();
         __shared__ char smem[wg_size * 4 * sizeof(float)];       // should enough
+        char* smem_quant_q = smem + wg_size * 2 * sizeof(float); // second half, should enough
         int i_dv    = blockIdx.x * wg_size + threadIdx.x; // index of hdim_v
         int i_sq    = blockIdx.y;                         // index of seqlen_q
         int i_batch = blockIdx.z;                         // index of batch_q * nhead_q
         int i_bq    = i_batch / args.nhead_q;             // index of batch_q
         int i_hq    = i_batch % args.nhead_q;             // index of nhead_q
         int i_bk = i_bq / args.batch_ratio_kv;
         int i_hk = i_hq / args.nhead_ratio_kv;

@@ -360,9 +397,10 @@ struct naive_attention_fwd_kernel
         auto f_max = [](auto x_, auto y_) { return max(x_, y_); };
         auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
         auto f_absmax_f32 = [](float v_0_, float v_1_) {
-            float rtn;
-            asm volatile("v_max_f32 %0, abs(%1), abs(%2)" : "=v"(rtn) : "v"(v_0_), "v"(v_1_));
-            return rtn;
+            // float rtn;
+            // asm volatile("v_max_f32 %0, abs(%1), abs(%2)" : "=v"(rtn) : "v"(v_0_), "v"(v_1_));
+            // return rtn;
+            return max(abs(v_0_), abs(v_1_));
         };

         int seqlen_kv = [&]() {
@@ -378,45 +416,82 @@ struct naive_attention_fwd_kernel
         SoftmaxType row_max = -numeric<SoftmaxType>::infinity();
         SoftmaxType l{0};
-        AccType o_acc = {0};
+        // AccType o_acc = {0};
+        OAccType o_acc = {0};

         int sk_loops = (seqlen_kv + wg_size - 1) / wg_size;
-        float qf_scale = .0f;
-        kvscale_addresser<KVScaleType> kvscale_addr{args.nhead_kv, args.hdim, args.kvscale_ptr};
+        QuantComputeType q_dequant_scale = .0f;
+        kvscale_addresser<KVScaleType, KScaleLayout> kscale_addr{
+            args.max_kv_tokens, args.nhead_kv, args.hdim, args.kscale_ptr};
+        kvscale_addresser<KVScaleType, VScaleLayout> vscale_addr{
+            args.max_kv_tokens, args.nhead_kv, args.hdim_v, args.vscale_ptr};

-        if constexpr(is_kvcache_i8_forward_quant)
+        if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
         {
             // AccType is i32 now, seqlen_q = 1, hdim up to 256
-            float q   = 0;
-            float k_s = 0;
+            AccType q   = 0;
+            AccType k_s = 0;
             if(static_cast<int>(threadIdx.x) < args.hdim)
             {
-                q   = type_convert<float>(q_addr.load(0, threadIdx.x));
-                k_s = type_convert<float>(kvscale_addr.load(i_hk, threadIdx.x, 0));
+                q   = type_convert<AccType>(q_addr.load(0, threadIdx.x));
+                k_s = type_convert<AccType>(kscale_addr.load(i_hk, threadIdx.x, 0));
             }
             // 1) we apply the k scale to q
-            float q_forwarded = q * k_s;
+            AccType q_forwarded = q * k_s;
             // 2) apply smooth-quant
             // find absmax
-            float qf_max = wave_reduce(q_forwarded, f_absmax_f32);
-            qf_max = cross_wave_reduce(qf_max, f_absmax_f32, reinterpret_cast<float*>(smem));
+            AccType qf_max = wave_reduce(q_forwarded, f_absmax_f32);
+            qf_max = cross_wave_reduce(qf_max, f_absmax_f32, reinterpret_cast<AccType*>(smem));
             // per-token scale
-            qf_scale = qf_max / 127.0;
+            q_dequant_scale = type_convert<QuantComputeType>(qf_max) / scale_max<QCompute>::value;
             // devide by scale
-            q = q / qf_scale;
+            q = q / q_dequant_scale;
             // fp32->i8
-            int8_t quantized_q = static_cast<int8_t>(q);
+            QCompute quantized_q = static_cast<QCompute>(q);
             __syncthreads();
-            reinterpret_cast<int8_t*>(smem)[threadIdx.x] = quantized_q;
+            reinterpret_cast<QCompute*>(smem)[threadIdx.x] = quantized_q;
             __syncthreads();
             // after above process, we have 2 data
             // 1) int8 q data stored in smem(no need to reload)
-            // 2) per-token scale qf_scale, to be mul after 1st gemm
+            // 2) per-token scale q_dequant_scale, to be mul after 1st gemm
         }
+        else if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERTOKEN)
+        {
+            if(std::is_same_v<QType, fp16_t> || std::is_same_v<QType, bf16_t>)
+            {
+                // dyanmic quant q here
+                float q = 0;
+                if(static_cast<int>(threadIdx.x) < args.hdim)
+                {
+                    q = type_convert<float>(q_addr.load(i_sq, threadIdx.x));
+                }
+                // apply smooth-quant
+                // find absmax
+                float q_max = wave_reduce(q, f_absmax_f32);
+                q_max = cross_wave_reduce(q_max, f_absmax_f32, reinterpret_cast<float*>(smem));
+                // per-token scale
+                q_dequant_scale = type_convert<QuantComputeType>(q_max) / scale_max<QCompute>::value;
+                // devide by scale
+                q = q / q_dequant_scale;
+                QCompute quantized_q = type_convert<QCompute>(q);
+                __syncthreads();
+                reinterpret_cast<QCompute*>(smem_quant_q)[threadIdx.x] = quantized_q;
+                __syncthreads();
+                // after above process, we have 2 data
+                // 1) fp8 q data stored in smem(no need to reload from global)
+                // 2) per-token scale q_dequant_scale, to be mul after 1st gemm
+            }
+        }

         for(int i_loop1 = 0; i_loop1 < sk_loops; i_loop1++)
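The KV_8BIT_PERTOKEN path above is a smooth-quant style dynamic quantization: take the absolute maximum of the row, derive a dequant scale as absmax / scale_max (127 for int8, 240 for fp8 in the scale_max specializations earlier), divide by it before the low-precision GEMM, and multiply the scales back in afterwards. Below is a hedged scalar sketch of that round trip in plain C++ with int8; it illustrates the arithmetic only and is not the device code.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Per-token int8 quantization round trip: quantize a row of Q, run the dot
// product in int32, then dequantize the accumulator with the saved scales.
float quantized_dot(const std::vector<float>& q, const std::vector<int8_t>& k,
                    float k_scale /* per-token K dequant scale */)
{
    // 1) find absmax of the row
    float absmax = 0.f;
    for(float v : q)
        absmax = std::max(absmax, std::fabs(v));

    // 2) per-token dequant scale (127.0 plays the role of scale_max<int8_t>)
    float q_scale = absmax / 127.0f;

    // 3) quantize and accumulate in int32
    int32_t acc = 0;
    for(size_t i = 0; i < q.size(); ++i)
    {
        int8_t qi = static_cast<int8_t>(std::round(q[i] / q_scale));
        acc += static_cast<int32_t>(qi) * static_cast<int32_t>(k[i]);
    }

    // 4) multiply the scales back in after the GEMM
    return static_cast<float>(acc) * q_scale * k_scale;
}
```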
@@ -429,33 +504,41 @@ struct naive_attention_fwd_kernel
             AccType s_acc{0}; // clear for every loop
             for(auto i_dq = 0; i_dq < args.hdim; i_dq++)
             {
-                if constexpr(is_kvcache_i8_forward_quant)
-                {
-                    int8_t q = reinterpret_cast<int8_t*>(smem)[i_dq];
-                    auto k   = k_addr.load(i_sk, i_dq);
-                    s_acc += type_convert<AccType>(q) * type_convert<AccType>(k);
-                }
-                else
-                {
-                    auto q = q_addr.load(i_sq, i_dq); // q will have duplicate load
-                    auto k = k_addr.load(i_sk, i_dq);
-                    s_acc += type_convert<AccType>(q) * type_convert<AccType>(k);
-                }
+                auto q = [&]() {
+                    if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD ||
+                                 Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERTOKEN)
+                    {
+                        return reinterpret_cast<QCompute*>(smem_quant_q)[i_dq];
+                    }
+                    else
+                        return q_addr.load(i_sq, i_dq); // q will have duplicate load
+                }();
+                auto k = [&]() { return k_addr.load(i_sk, i_dq); }();
+                s_acc += type_convert<AccType>(q) * type_convert<AccType>(k);
             }
             // scale
             s_softmax = type_convert<SoftmaxType>(s_acc);
             s_softmax *= type_convert<SoftmaxType>(args.scale_s * ck_tile::log2e_v<SoftmaxType>);
-            if constexpr(is_kvcache_i8_forward_quant)
+            if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
             {
-                s_softmax *= qf_scale; // post scale the per-token factor
+                s_softmax *= q_dequant_scale; // post scale the per-token factor
             }
+            else if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERTOKEN)
+            {
+                SoftmaxType k_per_token_scale =
+                    type_convert<SoftmaxType>(kscale_addr.load(i_sk, i_hk, 0));
+                s_softmax *= q_dequant_scale;
+                s_softmax *= k_per_token_scale;
+            }

             // s->p
-            float pf_scale = 0.; // used for i8 quant
+            QuantComputeType p_dequant_scale = 1.;
             {
                 // softmax, find max
                 SoftmaxType old_max = row_max;
@@ -473,41 +556,69 @@ struct naive_attention_fwd_kernel
                 // l, pre-scall o_acc
                 SoftmaxType tmp = __builtin_amdgcn_exp2f(old_max - row_max);
                 l = tmp * l + row_sum;
-                o_acc = type_convert<AccType>(type_convert<SoftmaxType>(o_acc) * tmp);
+                o_acc = type_convert<OAccType>(type_convert<SoftmaxType>(o_acc) * tmp);
             }

             // prepare the p_compute into smem, to let every thread read same p_compute and do
             // 2nd gemm
-            if constexpr(is_kvcache_i8_forward_quant)
+            if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
             {
-                float v_s = 0;
+                QuantComputeType v_s = 0;
                 if(static_cast<int>(threadIdx.x) < args.hdim_v)
                 {
-                    v_s = type_convert<float>(kvscale_addr.load(i_hk, threadIdx.x, 1));
+                    v_s = type_convert<QuantComputeType>(vscale_addr.load(i_hk, threadIdx.x, 1));
                 }
                 // 1) we apply the v scale to p
-                float p_forwarded = p_compute * v_s;
+                QuantComputeType p_forwarded = p_compute * v_s;
                 // 2) apply smooth-quant
                 // find absmax
-                float pf_max = wave_reduce(p_forwarded, f_absmax_f32);
-                pf_max = cross_wave_reduce(pf_max, f_absmax_f32, reinterpret_cast<float*>(smem));
+                QuantComputeType pf_max = wave_reduce(p_forwarded, f_absmax_f32);
+                pf_max = cross_wave_reduce(pf_max, f_absmax_f32, reinterpret_cast<QuantComputeType*>(smem));
                 // per-token scale
-                pf_scale = pf_max / 127.0;
+                p_dequant_scale = pf_max / scale_max<PType>::value; // 127.0;
                 // devide by scale
-                p_compute = p_compute / pf_scale;
+                p_compute = p_compute / p_dequant_scale;
                 // fp32->i8
-                int8_t quantized_p = static_cast<int8_t>(p_compute);
+                PType quantized_p = static_cast<PType>(p_compute);
                 __syncthreads();
-                reinterpret_cast<int8_t*>(smem)[threadIdx.x] = quantized_p;
+                reinterpret_cast<PType*>(smem)[threadIdx.x] = quantized_p;
                 __syncthreads();
                 // after above process, we have 2 data
                 // 1) int8 p data stored in smem(no need to reload)
-                // 2) per-token scale pf_scale, to be mul after 2nd gemm
+                // 2) per-token scale p_dequant_scale, to be mul after 2nd gemm
             }
+            else if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERTOKEN)
+            {
+                // forward apply the v scale to p_compute, this is compute friendly
+                auto v_scale = type_convert<QuantComputeType>(vscale_addr.load(i_sk, i_hk, 0));
+                p_compute *= v_scale;
+                // smooth-quant
+                // find absmax
+                QuantComputeType p_max = wave_reduce(p_compute, f_absmax_f32);
+                p_max = cross_wave_reduce(p_max, f_absmax_f32, reinterpret_cast<QuantComputeType*>(smem));
+                // per-token scale
+                p_dequant_scale = p_max / scale_max<PType>::value; // 240.0;
+                // devide by scale
+                p_compute = p_compute / p_dequant_scale;
+                // fp32->i8
+                PType quantized_p = type_convert<PType>(p_compute);
+                __syncthreads();
+                reinterpret_cast<PType*>(smem)[threadIdx.x] = quantized_p;
+                __syncthreads();
+                // after above process, we have 2 data
+                // 1) fp8_t p data stored in smem(no need to reload)
+                // 2) per-token scale p_dequant_scale, to be mul after 2nd gemm
+            }
             else
             {
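The o_acc rescaling in this hunk (l = tmp * l + row_sum, o_acc multiplied by tmp) is the standard online-softmax bookkeeping: when a new tile raises the running row maximum, the previously accumulated output and normalizer are scaled by exp(old_max - new_max) so that the final division by l still yields softmax(s) applied to V. A scalar sketch follows; it is plain C++ for illustration and uses exp rather than the device exp2 form.

```cpp
#include <cmath>
#include <vector>

// Online softmax accumulation over tiles of scores s with values v.
// After processing all tiles, o / l equals sum_i softmax(s)_i * v_i.
void online_softmax_step(const std::vector<float>& s_tile,
                         const std::vector<float>& v_tile,
                         float& row_max, float& l, float& o)
{
    float old_max = row_max;
    for(float s : s_tile)
        row_max = std::max(row_max, s);

    float rescale = std::exp(old_max - row_max); // <= 1, shrinks old contributions
    float row_sum = 0.f, o_local = 0.f;
    for(size_t i = 0; i < s_tile.size(); ++i)
    {
        float p = std::exp(s_tile[i] - row_max);
        row_sum += p;
        o_local += p * v_tile[i];
    }
    l = rescale * l + row_sum;
    o = rescale * o + o_local;
}
```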
@@ -531,29 +642,45 @@ struct naive_attention_fwd_kernel
                     int sv_offset = i_loop2 * p_vec_elem + i_j;
                     int i_sv      = sk_start + sv_offset;
-                    VType v = 0.f;
+                    VType v = 0;
                     if(i_dv < args.hdim_v && i_sv < seqlen_kv)
                     {
                         v = v_addr.load(i_sv, i_dv);
                     }
-                    o_acc_local += type_convert<AccType>(p_vec[i_j]) * type_convert<AccType>(v);
+                    AccType v_compute = [&]() { return type_convert<AccType>(v); }();
+                    o_acc_local += type_convert<AccType>(p_vec[i_j]) * v_compute;
                 }
             }
-            if constexpr(is_kvcache_i8_forward_quant)
-            {
-                // apply pr scale to local acc
-                o_acc_local = type_convert<AccType>(type_convert<float>(o_acc_local) * pf_scale);
-            }
-            o_acc += o_acc_local;
+            OAccType post_scale_o_acc_local = [&]() {
+                if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
+                {
+                    // apply pr scale to local acc
+                    return type_convert<OAccType>(type_convert<QuantComputeType>(o_acc_local) * p_dequant_scale);
+                }
+                else if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERTOKEN)
+                {
+                    // apply pr scale to local acc
+                    return type_convert<OAccType>(type_convert<QuantComputeType>(o_acc_local) * p_dequant_scale);
+                }
+                else
+                {
+                    return type_convert<OAccType>(o_acc_local);
+                }
+            }();
+            o_acc += post_scale_o_acc_local;
         }
     }

     // post scale o_acc
     {
         SoftmaxType tmp = l == 0.f ? 0.f : 1.f / l; // in case masking
-        o_acc = type_convert<AccType>(type_convert<SoftmaxType>(o_acc) * tmp);
+        o_acc = type_convert<OAccType>(type_convert<SoftmaxType>(o_acc) * tmp);
     }

     // store O
@@ -564,18 +691,21 @@ struct naive_attention_fwd_kernel
 #define CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_()                                       \
     {                                                                                      \
-        using ktraits_ =                                                                   \
-            naive_attention_fwd_kernel_traits<static_cast<naive_attention_variation_enum>( \
-                variation_)>;                                                              \
+        using ktraits_ = naive_attention_fwd_kernel_traits<                                \
+            static_cast<naive_attention_variation_enum>(variation_),                       \
+            static_cast<naive_attention_quant_algo>(quant_algo_)>;                         \
         using k_ = naive_attention_fwd_kernel<q_type_,                                     \
                                               k_type_,                                     \
                                               v_type_,                                     \
                                               o_type_,                                     \
                                               acc_type_,                                   \
+                                              kvscale_type_,                               \
                                               q_layout_,                                   \
                                               k_layout_,                                   \
                                               v_layout_,                                   \
                                               o_layout_,                                   \
+                                              k_scale_layout_,                             \
+                                              v_scale_layout_,                             \
                                               ktraits_>;                                   \
         dim3 grids = k_::get_grid_size(a);                                                 \
         r = ck_tile::launch_kernel(s,                                                      \

@@ -586,31 +716,37 @@ struct naive_attention_fwd_kernel
     if(t.variation == 0 && t.q_layout == "bshd" && t.k_layout == "bshd" && t.v_layout == "bshd" && \
        t.o_layout == "bshd")                                                                       \
     {                                                                                              \
         constexpr auto q_layout_ = naive_attention_layout_enum::BSHD;                              \
         constexpr auto k_layout_ = naive_attention_layout_enum::BSHD;                              \
         constexpr auto v_layout_ = naive_attention_layout_enum::BSHD;                              \
         constexpr auto o_layout_ = naive_attention_layout_enum::BSHD;                              \
+        constexpr auto k_scale_layout_ = naive_attention_layout_enum::DEFAULT;                     \
+        constexpr auto v_scale_layout_ = naive_attention_layout_enum::DEFAULT;                     \
         constexpr int variation_ = 0;                                                              \
         CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_();                                              \
     }                                                                                              \
     else if(t.variation == 0 && t.q_layout == "bhsd" && t.k_layout == "bhsd" &&                    \
             t.v_layout == "bhsd" && t.o_layout == "bhsd")                                          \
     {                                                                                              \
         constexpr auto q_layout_ = naive_attention_layout_enum::BHSD;                              \
         constexpr auto k_layout_ = naive_attention_layout_enum::BHSD;                              \
         constexpr auto v_layout_ = naive_attention_layout_enum::BHSD;                              \
         constexpr auto o_layout_ = naive_attention_layout_enum::BHSD;                              \
+        constexpr auto k_scale_layout_ = naive_attention_layout_enum::DEFAULT;                     \
+        constexpr auto v_scale_layout_ = naive_attention_layout_enum::DEFAULT;                     \
         constexpr int variation_ = 0;                                                              \
         CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_();                                              \
     }                                                                                              \
     else if(t.variation == 2 && t.q_layout == "bhsd" && t.k_layout == "phdsx" &&                   \
             t.v_layout == "phds" && t.o_layout == "bhsd")                                          \
     {                                                                                              \
         constexpr auto q_layout_ = naive_attention_layout_enum::BHSD;                              \
         constexpr auto k_layout_ = naive_attention_layout_enum::PHDSX;                             \
         constexpr auto v_layout_ = naive_attention_layout_enum::PHDS;                              \
         constexpr auto o_layout_ = naive_attention_layout_enum::BHSD;                              \
+        constexpr auto k_scale_layout_ = naive_attention_layout_enum::SCALE_HS;                    \
+        constexpr auto v_scale_layout_ = naive_attention_layout_enum::SCALE_HS;                    \
         constexpr int variation_ = 2;                                                              \
         CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_();                                              \
     }
@@ -621,40 +757,64 @@ CK_TILE_HOST float naive_attention_fwd(naive_attention_fwd_traits t,
 {
     float r = -1;
     // TODO: do not explicitly create too much instance!
-    if(t.q_type == "fp16" && t.k_type == "fp16" && t.v_type == "fp16" && t.o_type == "fp16")
+    if(t.q_type == "fp16" && t.k_type == "fp16" && t.v_type == "fp16" && t.o_type == "fp16" &&
+       t.quant_algo == 0)
     {
         using q_type_   = fp16_t;
         using k_type_   = fp16_t;
         using v_type_   = fp16_t;
         using o_type_   = fp16_t;
         using acc_type_ = float;
+        using kvscale_type_ = float;
+        constexpr int quant_algo_ = 0;
         CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
     }
+    else if(t.q_type == "bf16" && t.k_type == "bf16" && t.v_type == "bf16" && t.o_type == "bf16" &&
+            t.quant_algo == 0)
+    {
-        using q_type_   = fp16_t;
-        using k_type_   = fp16_t;
-        using v_type_   = fp16_t;
-        using o_type_   = fp16_t;
-        using acc_type_ = float;
+        using q_type_   = bf16_t;
+        using k_type_   = bf16_t;
+        using v_type_   = bf16_t;
+        using o_type_   = bf16_t;
+        using acc_type_ = float;
+        using kvscale_type_ = float;
+        constexpr int quant_algo_ = 0;
+        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
+    }
-    else if(t.q_type == "bf16" && t.k_type == "bf16" && t.v_type == "bf16" && t.o_type == "bf16")
+    else if(t.q_type == "bf16" && t.k_type == "fp8" && t.v_type == "fp8" && t.o_type == "bf16" &&
+            t.quant_algo == 2)
     {
-        using q_type_   = bf16_t;
-        using k_type_   = bf16_t;
-        using v_type_   = bf16_t;
-        using o_type_   = bf16_t;
-        using acc_type_ = float;
+        using q_type_   = bf16_t;
+        using k_type_   = fp8_t;
+        using v_type_   = fp8_t;
+        using o_type_   = bf16_t;
+        using acc_type_ = float; // NOTE!
+        using kvscale_type_ = float;
+        constexpr int quant_algo_ = 2;
         CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
     }
-    else if(t.q_type == "bf16" && t.k_type == "int8" && t.v_type == "int8" && t.o_type == "bf16")
+    else if(t.q_type == "fp16" && t.k_type == "fp8" && t.v_type == "fp8" && t.o_type == "fp16" &&
+            t.quant_algo == 2)
     {
-        using q_type_   = bf16_t;
-        using k_type_   = int8_t;
-        using v_type_   = int8_t;
-        using o_type_   = bf16_t;
-        using acc_type_ = int32_t; // NOTE!
+        using q_type_   = fp16_t;
+        using k_type_   = fp8_t;
+        using v_type_   = fp8_t;
+        using o_type_   = fp16_t;
+        using acc_type_ = float; // NOTE!
+        using kvscale_type_ = float;
+        constexpr int quant_algo_ = 2;
         CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
     }
-    else if(t.q_type == "fp16" && t.k_type == "int8" && t.v_type == "int8" && t.o_type == "fp16")
+    else if(t.q_type == "bf16" && t.k_type == "int8" && t.v_type == "int8" && t.o_type == "bf16" &&
+            t.quant_algo == 2)
     {
-        using q_type_   = fp16_t;
-        using k_type_   = int8_t;
-        using v_type_   = int8_t;
-        using o_type_   = fp16_t;
-        using acc_type_ = int32_t; // NOTE!
+        using q_type_   = bf16_t;
+        using k_type_   = int8_t;
+        using v_type_   = int8_t;
+        using o_type_   = bf16_t;
+        using acc_type_ = int32_t; // NOTE!
+        using kvscale_type_ = float;
+        constexpr int quant_algo_ = 2;
         CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
     }
     return r;
library/include/ck/library/tensor_operation_instance/gpu/gemm_b_scale.hpp
(new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include <memory>
#include <vector>

#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
void add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2BScale<Row, Col, Row, F16, I4, F16, F16, 1, 128,
        PassThrough, PassThrough, PassThrough>>>& instances);
#endif

template <typename ADataType,
          typename BDataType,
          typename BScaleDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          index_t ScaleBlockK>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGemmV2BScale<ALayout,
                                                     BLayout,
                                                     CLayout,
                                                     ADataType,
                                                     BDataType,
                                                     BScaleDataType,
                                                     CDataType,
                                                     1,
                                                     ScaleBlockK,
                                                     ck::tensor_operation::element_wise::PassThrough,
                                                     ck::tensor_operation::element_wise::PassThrough,
                                                     ck::tensor_operation::element_wise::PassThrough>>
{
    using DeviceOp = DeviceGemmV2BScale<ALayout,
                                        BLayout,
                                        CLayout,
                                        ADataType,
                                        BDataType,
                                        BScaleDataType,
                                        CDataType,
                                        1,
                                        ScaleBlockK,
                                        ck::tensor_operation::element_wise::PassThrough,
                                        ck::tensor_operation::element_wise::PassThrough,
                                        ck::tensor_operation::element_wise::PassThrough>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, pk_i4_t> &&
                     is_same_v<CDataType, half_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                         is_same_v<CLayout, Row>)
            {
                add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
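A hedged sketch of how client code typically pulls these instances out of the factory. The layout/type choices and the ScaleBlockK value of 128 mirror the specialization above; the aliases, the include set, and the loop body are illustrative assumptions rather than code from this commit.

```cpp
// Illustrative use of the new B-scale GEMM instance factory (fp16 A/C,
// packed-int4 B, fp16 per-block B scales). Everything outside the factory
// query itself is example scaffolding.
#include <iostream>

using ck::tensor_operation::device::instance::DeviceOperationInstanceFactory;
using F16 = ck::half_t;
using I4  = ck::pk_i4_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

void list_b_scale_gemm_instances()
{
    using Factory = DeviceOperationInstanceFactory<
        ck::tensor_operation::device::DeviceGemmV2BScale<
            Row, Col, Row, F16, I4, F16, F16, 1, 128,
            PassThrough, PassThrough, PassThrough>>;

    // Each op_ptr is a type-erased device GEMM instance; callers usually
    // iterate, build an Argument, check IsSupportedArgument, and time it.
    for(auto& op_ptr : Factory::GetInstances())
        std::cout << op_ptr->GetTypeString() << std::endl;
}
```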
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp
View file @
67ab3896
...
...
@@ -238,6 +238,403 @@ void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpaddin
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances(
    std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances(
    std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_default_instances(
    std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances(
    std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances(
    std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances(
    std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#if(defined(CK_ENABLE_FP8))
void add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_default_instances(
    std::vector<std::unique_ptr<
...
...
@@ -527,6 +924,109 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemm_S
}
#endif
#ifdef CK_ENABLE_BF16
        if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
                     is_same_v<CDataType, bhalf_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> && is_same_v<CLayout, Row>)
            {
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> && is_same_v<CLayout, Row>)
            {
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> && is_same_v<CLayout, Row>)
            {
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> && is_same_v<CLayout, Row>)
            {
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances(op_ptrs);
            }
        }
#endif
#if(defined(CK_ENABLE_FP8))
        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
                     is_same_v<CDataType, half_t>)
...
...
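For orientation only (not part of this commit): the factory specialization touched above is what collects these per-layout, per-padding instance lists into a single vector of type-erased operation pointers. A minimal client sketch, assuming the usual CK factory interface (DeviceOperationInstanceFactory<DeviceOp>::GetInstances() and the base operator's GetTypeString()); the include path and the printing loop are assumptions, not code from this diff:

#include <iostream>
#include "ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp" // assumed factory header

int main()
{
    using ck::tensor_operation::device::DeviceGemm_Streamk_V2;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using Row         = ck::tensor_layout::gemm::RowMajor;
    using Col         = ck::tensor_layout::gemm::ColumnMajor;
    using BF16        = ck::bhalf_t;

    // One of the layout/data-type combinations registered above: A col-major, B row-major, C row-major, all bf16.
    using DeviceOp = DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>;

    const auto op_ptrs =
        ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    // Each pointer is one pre-built tuning configuration (comp/mem variant, padding specialization, pipeline version).
    for(const auto& op_ptr : op_ptrs)
        std::cout << op_ptr->GetTypeString() << '\n';
}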
library/src/tensor_operation_instance/gpu/CMakeLists.txt
100644 → 100755
...
...
@@ -183,6 +183,10 @@ FOREACH(subdir_path ${dir_list})
      message("bf8 instance found!")
      set(add_inst 1)
    endif()
    if(("${cmake_instance}" MATCHES "_bf16" OR "${cmake_instance}" MATCHES "_b16") AND DTYPES MATCHES "bf16")
      message("bf16 instance found!")
      set(add_inst 1)
    endif()
    if(("${cmake_instance}" MATCHES "_fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16")
      message("fp16 instance found!")
      set(add_inst 1)
...
...
@@ -195,10 +199,6 @@ FOREACH(subdir_path ${dir_list})
      message("fp64 instance found!")
      set(add_inst 1)
    endif()
    if("${cmake_instance}" MATCHES "_bf16" AND DTYPES MATCHES "bf16")
      message("bf16 instance found!")
      set(add_inst 1)
    endif()
    if(("${cmake_instance}" MATCHES "_int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8")
      message("int8 instance found!")
      set(add_inst 1)
...
...
library/src/tensor_operation_instance/gpu/gemm_b_scale/CMakeLists.txt
0 → 100644
# ONLY XDL_KERNELS
set(GEMM_B_SCALE_INSTANCES)

list(APPEND GEMM_B_SCALE_INSTANCES
     device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp)

set_source_files_properties(device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp
                            PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")

add_instance_library(device_gemm_b_scale_instance ${GEMM_B_SCALE_INSTANCES})
\ No newline at end of file
library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using I4  = pk_i4_t;
using F16 = half_t;
using F32 = float;

using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;

template <index_t... Is>
using S = Sequence<Is...>;

using PassThrough = element_wise::PassThrough;

static constexpr auto GemmDefault    = GemmSpecialization::Default;
static constexpr auto GemmKPadding   = GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding  = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;

static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
#if 0
template <GemmSpecialization GemmSpec>
using device_gemm_xdl_b_scale_f16_i4_f16_mk_nk_mn_comp_instances = std::tuple<
#endif
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple<
    // clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | Type| | | | Operation| Operation| Operation| | | N| K| | | | | |Wave| Wave| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        // Compute friendly
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 32, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 32, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
        // Latency friendly
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
        // Memory friendly v3
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 32, 128, 8, 32, 32, 32, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 16, 128, 8, 16, 16, 16, 4, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 32, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 16, 128, 8, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 64, 128, 8, 32, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 64, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 128, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 16, 256, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 32, 256, 128, 8, 32, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
        // Memory friendly v4
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 32, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 16, 128, 8, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 1, 128, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 32, 128, 8, 32, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 64, 128, 8, 32, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 64, 128, 8, 32, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 16, 128, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 16, 256, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 32, 256, 128, 8, 32, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>,
        // new Compute friendly kernel
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>,
        // new Memory friendly kernel
        DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 16, 64, 256, 8, 32, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>
    // clang-format on
    >;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
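Before moving on to the instance translation unit, it may help to state what these f16 x int4 B-scale kernels compute. The sketch below is a plain scalar reference, not code from this commit: it assumes row-major A, a B matrix given as unpacked int4 values with one scale per output column per group of 128 elements along K (matching the ScaleBlockN = 1 and ScaleBlockK = 128 arguments used by every instance above, as read from the column-header comments), and uses float stand-ins for the fp16 types.

#include <cstdint>
#include <vector>

// Scalar reference for a B-scale GEMM: C[m][n] = sum_k A[m][k] * (B_i4[n][k] * scale[n][k / 128]).
void gemm_b_scale_reference(int M, int N, int K,
                            const std::vector<float>& a,       // M x K, row-major (stand-in for f16)
                            const std::vector<int8_t>& b_i4,   // N x K, unpacked int4 values in [-8, 7]
                            const std::vector<float>& b_scale, // N x (K / 128) per-group scales (stand-in for f16)
                            std::vector<float>& c)             // M x N, row-major
{
    constexpr int group = 128; // assumed dequantization group size along K
    for(int m = 0; m < M; ++m)
    {
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.0f;
            for(int k = 0; k < K; ++k)
            {
                // dequantize one B element with its group scale, then accumulate
                const float w = static_cast<float>(b_i4[n * K + k]) * b_scale[n * (K / group) + k / group];
                acc += a[m * K + k] * w;
            }
            c[m * N + n] = acc;
        }
    }
}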
library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2BScale<Row, Col, Row, F16, I4, F16, F16, 1, 128, PassThrough, PassThrough, PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances<Intrawave, GemmDefault>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
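Each *_instance.cpp translation unit like the one above pins one scheduler/specialization pair of the instance tuple and appends a heap-allocated copy of every configuration in it to the caller's vector. A rough sketch of what that append amounts to, under the assumption that add_device_operation_instances simply copies each tuple element into a unique_ptr of the common base type; the helper name below is hypothetical, not the library's:

#include <memory>
#include <tuple>
#include <vector>

// Hypothetical stand-in for add_device_operation_instances: push one owned copy
// of every device-op configuration held in the tuple onto the instance list.
template <typename BaseOp, typename... Ops>
void append_instances_sketch(std::vector<std::unique_ptr<BaseOp>>& op_ptrs, const std::tuple<Ops...>& ops)
{
    std::apply([&](const Ops&... op) { (op_ptrs.push_back(std::make_unique<Ops>(op)), ...); }, ops);
}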
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/CMakeLists.txt
...
...
@@ -64,6 +64,43 @@ list(APPEND GEMM_UNIVERSAL_STREAMK_INSTANCES
device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
)
device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
)
add_instance_library(device_gemm_universal_streamk_instance ${GEMM_UNIVERSAL_STREAMK_INSTANCES})
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn.hpp
0 → 100755
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using BF16 = bhalf_t;
using F32  = float;

using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;

template <index_t... Is>
using S = Sequence<Is...>;

using PassThrough = element_wise::PassThrough;

static constexpr auto GemmDefault    = GemmSpecialization::Default;
static constexpr auto GemmKPadding   = GemmSpecialization::KPadding;
static constexpr auto GemmMPadding   = GemmSpecialization::MPadding;
static constexpr auto GemmMNPadding  = GemmSpecialization::MNPadding;
static constexpr auto GemmMKPadding  = GemmSpecialization::MKPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;

static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template <GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_instances = std::tuple<
    // clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
        DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
        DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
        DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
        DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
        DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
        DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
        // Can we support this kind of odd case? 224(256) = 28*8 + (4*8)
        //DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
        DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
        DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
        DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
    // clang-format on
    >;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_instances = std::tuple<
    // clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        // Latency friendly
        DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
        DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
        DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
        DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
        DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
        // Memory friendly
        DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 2, 16, 16, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
        DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 2, 2, 16, 16, 4, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
        DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 4, 16, 16, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
        DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 4, 4, 16, 16, 2, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
        DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
        DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
        DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
        DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 4, 4, 16, 16, 1, 2, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
        DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 4, 4, 16, 16, 1, 4, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
        DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 4, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
        DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 2, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
    // clang-format on
    >;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_default_instance.cpp
0 → 100755
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_default_instances(
    std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_instances<GemmDefault>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instance.cpp
0 → 100755
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_instances<GemmKPadding>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instance.cpp
0 → 100755
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_instances<GemmMNKPadding>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instance.cpp
0 → 100755
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances(
    std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_instances<GemmMNPadding>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instance.cpp
0 → 100755
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances(
    std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_instances<Intrawave, GemmDefault>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck