Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f221c2b0
Unverified
Commit
f221c2b0
authored
Oct 24, 2024
by
Illia Silin
Committed by
GitHub
Oct 24, 2024
Browse files
Merge pull request #203 from ROCm/merge_from_public
Merge from public
parents
140d2fa6
e1cd4121
Changes
134
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1354 additions
and
641 deletions
+1354
-641
include/ck_tile/ops/layernorm2d.hpp
include/ck_tile/ops/layernorm2d.hpp
+5
-2
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
...ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
+148
-351
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp
.../ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp
+78
-0
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp
...rm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp
+99
-0
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
...ayernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
+119
-0
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp
...layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp
+40
-0
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
...ayernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
+160
-0
include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp
...e/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp
+0
-35
include/ck_tile/ops/reduce/block/block_reduce.hpp
include/ck_tile/ops/reduce/block/block_reduce.hpp
+1
-1
include/ck_tile/ops/welford.hpp
include/ck_tile/ops/welford.hpp
+2
-1
include/ck_tile/ops/welford/block/block_welford.hpp
include/ck_tile/ops/welford/block/block_welford.hpp
+362
-0
include/ck_tile/ops/welford/block/block_welford_problem.hpp
include/ck_tile/ops/welford/block/block_welford_problem.hpp
+18
-0
include/ck_tile/ops/welford/thread/thread_welford.hpp
include/ck_tile/ops/welford/thread/thread_welford.hpp
+24
-89
include/ck_tile/ops/welford/warp/warp_welford.hpp
include/ck_tile/ops/welford/warp/warp_welford.hpp
+0
-154
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp
...device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp
+68
-3
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp
...wd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp
+37
-2
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
...tion_instance/gpu/grouped_convolution_backward_weight.hpp
+52
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc
..._instance/gpu/grouped_convolution_backward_weight_xdl.inc
+118
-0
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp
...f16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp
+8
-1
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp
...f16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp
+15
-2
No files found.
include/ck_tile/ops/layernorm2d.hpp
View file @
f221c2b0
...
@@ -4,6 +4,9 @@
...
@@ -4,6 +4,9 @@
#pragma once
#pragma once
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp"
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp"
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
View file @
f221c2b0
...
@@ -5,37 +5,57 @@
...
@@ -5,37 +5,57 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/welford/thread/thread_welford.hpp"
#include "ck_tile/ops/welford/warp/warp_welford.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
// TODO: Extract some type to wrapper class
// host side args
template
<
typename
Problem_
>
struct
Layernorm2dFwdHostArgs
struct
Layernorm2dFwd
{
{
using
Problem
=
ck_tile
::
remove_cvref_t
<
Problem_
>
;
const
void
*
p_x
;
const
void
*
p_gamma
;
const
void
*
p_beta
;
using
XDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
void
*
p_y
;
using
GammaDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
GammaDataType
>
;
void
*
p_mean
;
using
BetaDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
BetaDataType
>
;
void
*
p_invStd
;
using
ComputeDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
YDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
MeanDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
MeanDataType
>
;
using
InvStdDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
InvStdDataType
>
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
ck_tile
::
null_type
>
;
float
epsilon
;
static
constexpr
bool
kHasBeta
=
!
std
::
is_same_v
<
BetaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kSaveMean
=
!
std
::
is_same_v
<
MeanDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kSaveInvStd
=
!
std
::
is_same_v
<
InvStdDataType
,
ck_tile
::
null_type
>
;
static
constexpr
ck_tile
::
index_t
kMPerBlock
=
Problem
::
BlockShape
::
kMPerBlock
;
index_t
m
;
static
constexpr
ck_tile
::
index_t
kNPerBlock
=
Problem
::
BlockShape
::
kNPerBlock
;
index_t
n
;
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
index_t
stride
;
// row_stride
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
}
;
static
constexpr
ck_tile
::
index_t
kNThreadPerWarp
=
Problem
::
BlockShape
::
kNThreadPerWarp
;
// TODO: Extract some type to wrapper class
static
constexpr
ck_tile
::
index_t
kNPerThread
=
Problem
::
BlockShape
::
kNPerThread
;
template
<
typename
Pipeline_
>
struct
Layernorm2dFwd
{
using
Pipeline
=
remove_cvref_t
<
Pipeline_
>
;
using
Problem
=
typename
Pipeline
::
Problem
;
using
XDataType
=
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
GammaDataType
=
remove_cvref_t
<
typename
Problem
::
GammaDataType
>
;
using
BetaDataType
=
remove_cvref_t
<
typename
Problem
::
BetaDataType
>
;
using
ComputeDataType
=
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
YDataType
=
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
MeanDataType
=
remove_cvref_t
<
typename
Problem
::
MeanDataType
>
;
using
InvStdDataType
=
remove_cvref_t
<
typename
Problem
::
InvStdDataType
>
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
null_type
>
;
static
constexpr
bool
kHasBeta
=
!
std
::
is_same_v
<
BetaDataType
,
null_type
>
;
static
constexpr
bool
kSaveMeanInvStd
=
Problem
::
kSaveMeanInvStd
;
static
constexpr
bool
kSaveMean
=
Problem
::
kSaveMeanInvStd
;
static
constexpr
bool
kSaveInvStd
=
Problem
::
kSaveMeanInvStd
;
static
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M
;
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
static
constexpr
bool
kPadM
=
false
;
// always no need to pad along M
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kTwoPass
=
Problem
::
kTwoPass
;
static
constexpr
index_t
ThreadPerWarp_N
=
Problem
::
BlockShape
::
ThreadPerWarp_N
;
static
constexpr
index_t
Vector_N
=
Problem
::
BlockShape
::
Vector_N
;
static
constexpr
index_t
Repeat_N
=
Problem
::
BlockShape
::
Repeat_N
;
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
...
@@ -52,400 +72,177 @@ struct Layernorm2dFwd
...
@@ -52,400 +72,177 @@ struct Layernorm2dFwd
float
epsilon
;
float
epsilon
;
ck_tile
::
index_t
M
;
index_t
m
;
ck_tile
::
index_t
N
;
index_t
n
;
index_t
stride
;
// row_stride
};
};
using
Hargs
=
Layernorm2dFwdHostArgs
;
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
void
*
p_x
,
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
const
void
*
p_gamma
,
const
void
*
p_beta
,
void
*
p_y
,
void
*
p_mean
,
void
*
p_invStd
,
float
epsilon
,
ck_tile
::
index_t
M
,
ck_tile
::
index_t
N
)
{
{
return
Kargs
{
p_x
,
p_gamma
,
p_beta
,
p_y
,
p_mean
,
p_invStd
,
epsilon
,
M
,
N
};
return
Kargs
{
hargs
.
p_x
,
hargs
.
p_gamma
,
hargs
.
p_beta
,
hargs
.
p_y
,
hargs
.
p_mean
,
hargs
.
p_invStd
,
hargs
.
epsilon
,
hargs
.
m
,
hargs
.
n
,
hargs
.
stride
};
}
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
M
)
{
return
M
/
kMPerBlock
;
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
Problem
::
BlockShape
::
kBlockSize
;
}
CK_TILE_DEVICE
static
constexpr
auto
MakeXBlockTileDistribution
()
{
{
using
S
=
typename
Problem
::
BlockShape
;
return
(
hargs
.
m
+
Block_M
-
1
)
/
Block_M
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
S
::
kMWarpPerBlock
,
S
::
kMThreadPerWarp
,
S
::
kMPerThread
>
,
sequence
<
S
::
kNWarpPerBlock
,
S
::
kNThreadPerWarp
,
S
::
kNPerThread
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
0
>
,
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
2
,
2
>>
{});
}
}
CK_TILE_DEVICE
static
constexpr
auto
MakeGammaBetaBlockTileDistribution
()
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
Problem
::
BlockShape
::
BlockSize
;
}
{
using
S
=
typename
Problem
::
BlockShape
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
S
::
kMWarpPerBlock
,
S
::
kMThreadPerWarp
>
,
tuple
<
sequence
<
S
::
kNWarpPerBlock
,
S
::
kNThreadPerWarp
,
S
::
kNPerThread
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
0
>
,
sequence
<
1
,
1
>>
,
sequence
<
1
>
,
sequence
<
2
>>
{});
}
CK_TILE_DEVICE
static
int
GetWelfordMaxCount
(
int
N
)
{
constexpr
ck_tile
::
index_t
kNThreadPerBlock
=
kNPerBlock
/
kNPerThread
;
int
thread_id_n
=
get_thread_id
()
%
kNThreadPerBlock
;
int
max_count
=
__builtin_amdgcn_readfirstlane
(
N
<
kNPerBlock
?
0
:
kNPerThread
*
(
N
/
kNPerBlock
));
int
n_per_block_tail_loop
=
__builtin_amdgcn_readfirstlane
(
N
-
max_count
*
kNThreadPerBlock
);
if
(
n_per_block_tail_loop
>
0
)
{
int
thread_max_n
=
(
thread_id_n
+
1
)
*
kNPerThread
;
int
delta
=
thread_max_n
-
n_per_block_tail_loop
;
delta
=
clamp
(
thread_max_n
-
n_per_block_tail_loop
,
0
,
kNPerThread
);
max_count
+=
kNPerThread
-
delta
;
}
return
max_count
;
}
template
<
typename
DistributedTensor
>
// clang-format off
CK_TILE_DEVICE
static
auto
InvSqrt
(
const
DistributedTensor
&
in_dstr_tensor
,
template
<
typename
T
>
struct
t2s
;
const
ComputeDataType
epsilon
)
template
<
>
struct
t2s
<
float
>
{
static
constexpr
const
char
*
name
=
"fp32"
;
};
{
template
<
>
struct
t2s
<
ck_tile
::
fp16_t
>
{
static
constexpr
const
char
*
name
=
"fp16"
;
};
// TODO: Investigate fast inverse square root algorithm with epsilon
template
<
>
struct
t2s
<
ck_tile
::
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
constexpr
auto
spans
=
DistributedTensor
::
get_distributed_spans
()
;
template
<
>
struct
t2s
<
ck_tile
::
fp8_t
>
{
static
constexpr
const
char
*
name
=
"fp8"
;
}
;
template
<
>
struct
t2s
<
ck_tile
::
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
DistributedTensor
out_dstr_tensor
;
// clang-format on
sweep_tile_span
(
spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
// in byte
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
Pipeline
::
GetSmemSize
();
}
out_dstr_tensor
(
i_idx
)
=
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
ck_tile
::
sqrt
(
in_dstr_tensor
[
i_idx
]
+
epsilon
);
});
return
out_dstr_tensor
;
}
template
<
typename
XBlockWindow
,
CK_TILE_HOST
static
std
::
string
GetName
()
typename
GammaBlockWindow
,
typename
BetaBlockWindow
,
typename
YBlockWindow
,
typename
MeanBlockWindow
,
typename
InvStdBlockWindow
,
bool
Cond
=
(
kHasGamma
&&
kHasBeta
)>
CK_TILE_DEVICE
std
::
enable_if_t
<
Cond
>
TwoPassLayernorm2dFwd
(
XBlockWindow
&
x_block_window
,
GammaBlockWindow
&
gamma_block_window
,
BetaBlockWindow
&
beta_block_window
,
YBlockWindow
&
y_block_window
,
MeanBlockWindow
&
mean_block_window
,
InvStdBlockWindow
&
inv_std_block_window
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
N
)
const
{
{
// TODO - Optimize tail loop to reduce move_tile_window()
// clang-format off
index_t
num_n_tile_iteration
=
using
S_
=
typename
Problem
::
BlockShape
;
__builtin_amdgcn_readfirstlane
(
integer_divide_ceil
(
N
,
kNPerBlock
));
auto
surfix
=
[
&
]
()
{
std
::
string
n
;
int
welford_max_count
=
GetWelfordMaxCount
(
N
);
if
(
kPadN
)
n
+=
"_pn"
;
ThreadWelford
<
ComputeDataType
,
XDataType
>
thread_welford
{
welford_max_count
};
if
(
kSaveMeanInvStd
)
n
+=
"_mv"
;
if
(
kTwoPass
)
n
+=
"_2p"
;
using
XTensorType
=
decltype
(
load_tile
(
x_block_window
));
return
n
;
}();
auto
mean_compute_block_tensor
=
thread_welford
.
template
MakeInitialMeanVarDistributedTensor
<
XTensorType
>();
#define _SS_ std::string
auto
var_compute_block_tensor
=
#define _TS_ std::to_string
thread_welford
.
template
MakeInitialMeanVarDistributedTensor
<
XTensorType
>();
return
_SS_
(
"layernorm2d_fwd_"
)
+
_SS_
(
t2s
<
XDataType
>::
name
)
+
"_"
+
_TS_
(
S_
::
Block_M
)
+
"x"
+
_TS_
(
S_
::
Block_N
)
+
"_"
+
_TS_
(
S_
::
WarpPerBlock_M
)
+
"x"
+
_TS_
(
S_
::
WarpPerBlock_N
)
+
"_"
+
clear_tile
(
mean_compute_block_tensor
);
_TS_
(
S_
::
Warp_M
)
+
"x"
+
_TS_
(
S_
::
Warp_N
)
+
"_"
+
_TS_
(
S_
::
Vector_M
)
+
"x"
+
_TS_
(
S_
::
Vector_N
)
+
"_"
+
clear_tile
(
var_compute_block_tensor
);
_SS_
(
Pipeline
::
name
)
+
surfix
;
#undef _SS_
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
#undef _TS_
{
// clang-format on
const
auto
x_block_tensor
=
load_tile
(
x_block_window
);
thread_welford
(
x_block_tensor
,
mean_compute_block_tensor
,
var_compute_block_tensor
);
move_tile_window
(
x_block_window
,
{
0
,
kNPerBlock
});
}
// TODO: support cross warp Welford
WarpMergeWelford
<
ComputeDataType
,
true
>
{}(
mean_compute_block_tensor
,
var_compute_block_tensor
,
thread_welford
.
cur_count_
);
auto
inv_std_compute_block_tensor
=
InvSqrt
(
var_compute_block_tensor
,
epsilon
);
if
constexpr
(
kSaveMean
)
store_tile
(
mean_block_window
,
cast_tile
<
MeanDataType
>
(
mean_compute_block_tensor
));
if
constexpr
(
kSaveInvStd
)
store_tile
(
inv_std_block_window
,
cast_tile
<
InvStdDataType
>
(
inv_std_compute_block_tensor
));
// reverse read x to reuse cache
ck_tile
::
index_t
stride_to_right_most_window
=
N
%
kNPerBlock
==
0
?
N
-
kNPerBlock
:
N
-
N
%
kNPerBlock
;
move_tile_window
(
x_block_window
,
{
0
,
-
kNPerBlock
});
move_tile_window
(
gamma_block_window
,
{
stride_to_right_most_window
});
move_tile_window
(
beta_block_window
,
{
stride_to_right_most_window
});
move_tile_window
(
y_block_window
,
{
0
,
stride_to_right_most_window
});
// Normalization
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
const
auto
x_block_tensor
=
load_tile
(
x_block_window
);
const
auto
gamma_block_tensor
=
load_tile
(
gamma_block_window
);
const
auto
beta_block_tensor
=
load_tile
(
beta_block_window
);
constexpr
auto
x_spans
=
decltype
(
x_block_tensor
)
::
get_distributed_spans
();
auto
y_block_tensor
=
make_static_distributed_tensor
<
YDataType
>
(
x_block_tensor
.
get_tile_distribution
());
sweep_tile_span
(
x_spans
[
I1
],
[
&
](
auto
idx1
)
{
constexpr
auto
j_idx
=
make_tuple
(
idx1
);
const
auto
gamma
=
type_convert
<
ComputeDataType
>
(
gamma_block_tensor
[
j_idx
]);
const
auto
beta
=
type_convert
<
ComputeDataType
>
(
beta_block_tensor
[
j_idx
]);
sweep_tile_span
(
x_spans
[
I0
],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
const
auto
mean
=
mean_compute_block_tensor
[
i_idx
];
const
auto
inv_std
=
inv_std_compute_block_tensor
[
i_idx
];
const
auto
x
=
type_convert
<
ComputeDataType
>
(
x_block_tensor
[
i_j_idx
]);
auto
y
=
(
x
-
mean
)
*
inv_std
*
gamma
+
beta
;
y_block_tensor
(
i_j_idx
)
=
type_convert
<
YDataType
>
(
y
);
});
});
store_tile
(
y_block_window
,
y_block_tensor
);
move_tile_window
(
x_block_window
,
{
0
,
-
kNPerBlock
});
move_tile_window
(
gamma_block_window
,
{
-
kNPerBlock
});
move_tile_window
(
beta_block_window
,
{
-
kNPerBlock
});
move_tile_window
(
y_block_window
,
{
0
,
-
kNPerBlock
});
}
}
template
<
typename
XBlockWindow
,
typename
GammaBlockWindow
,
typename
BetaBlockWindow
,
typename
YBlockWindow
,
typename
MeanBlockWindow
,
typename
InvStdBlockWindow
,
bool
Cond
=
(
kHasGamma
&&
kHasBeta
)>
CK_TILE_DEVICE
std
::
enable_if_t
<
Cond
>
OnePassLayernorm2dFwd
(
XBlockWindow
&
x_block_window
,
GammaBlockWindow
&
gamma_block_window
,
BetaBlockWindow
&
beta_block_window
,
YBlockWindow
&
y_block_window
,
MeanBlockWindow
&
mean_block_window
,
InvStdBlockWindow
&
inv_std_block_window
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
N
)
const
{
int
welford_max_count
=
GetWelfordMaxCount
(
N
);
ThreadWelford
<
ComputeDataType
,
XDataType
>
thread_welford
{
welford_max_count
};
using
XTensorType
=
decltype
(
load_tile
(
x_block_window
));
auto
mean_compute_block_tensor
=
thread_welford
.
template
MakeInitialMeanVarDistributedTensor
<
XTensorType
>();
auto
var_compute_block_tensor
=
thread_welford
.
template
MakeInitialMeanVarDistributedTensor
<
XTensorType
>();
clear_tile
(
mean_compute_block_tensor
);
clear_tile
(
var_compute_block_tensor
);
const
auto
x_block_tensor
=
load_tile
(
x_block_window
);
thread_welford
(
x_block_tensor
,
mean_compute_block_tensor
,
var_compute_block_tensor
);
// TODO: support cross warp Welford
WarpMergeWelford
<
ComputeDataType
,
true
>
{}(
mean_compute_block_tensor
,
var_compute_block_tensor
,
thread_welford
.
cur_count_
);
auto
inv_std_compute_block_tensor
=
InvSqrt
(
var_compute_block_tensor
,
epsilon
);
if
constexpr
(
kSaveMean
)
store_tile
(
mean_block_window
,
cast_tile
<
MeanDataType
>
(
mean_compute_block_tensor
));
if
constexpr
(
kSaveInvStd
)
store_tile
(
inv_std_block_window
,
cast_tile
<
InvStdDataType
>
(
inv_std_compute_block_tensor
));
// normalize
const
auto
gamma_block_tensor
=
load_tile
(
gamma_block_window
);
const
auto
beta_block_tensor
=
load_tile
(
beta_block_window
);
constexpr
auto
x_spans
=
decltype
(
x_block_tensor
)
::
get_distributed_spans
();
auto
y_block_tensor
=
make_static_distributed_tensor
<
YDataType
>
(
x_block_tensor
.
get_tile_distribution
());
sweep_tile_span
(
x_spans
[
I1
],
[
&
](
auto
idx1
)
{
constexpr
auto
j_idx
=
make_tuple
(
idx1
);
const
auto
gamma
=
type_convert
<
ComputeDataType
>
(
gamma_block_tensor
[
j_idx
]);
const
auto
beta
=
type_convert
<
ComputeDataType
>
(
beta_block_tensor
[
j_idx
]);
sweep_tile_span
(
x_spans
[
I0
],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
const
auto
mean
=
mean_compute_block_tensor
[
i_idx
];
const
auto
inv_std
=
inv_std_compute_block_tensor
[
i_idx
];
const
auto
x
=
type_convert
<
ComputeDataType
>
(
x_block_tensor
[
i_j_idx
]);
auto
y
=
(
x
-
mean
)
*
inv_std
*
gamma
+
beta
;
y_block_tensor
(
i_j_idx
)
=
type_convert
<
YDataType
>
(
y
);
});
});
store_tile
(
y_block_window
,
y_block_tensor
);
}
}
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
{
const
auto
x_m_n
=
[
&
]()
{
const
auto
iM
=
get_block_id
()
*
Block_M
;
const
auto
x_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
x_window
=
[
&
]()
{
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
XDataType
*>
(
kargs
.
p_x
),
static_cast
<
const
XDataType
*>
(
kargs
.
p_x
),
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
N
,
1
),
make_tuple
(
kargs
.
stride
,
1
),
number
<
kNPerThread
>
{},
number
<
Vector_N
>
{},
number
<
1
>
{});
number
<
1
>
{});
return
pad_tensor_view
(
x_dram_naive
,
// NOTE: we don't do any pad in this kernel for loading, assume that inside kernel will
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kNPerBlock
>
{}),
// check the max count dynamically
sequence
<
kPadM
,
kPadN
>
{});
const
auto
tmp2_
=
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
sequence
<
false
,
false
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
}();
}();
const
auto
gamma_
n
=
[
&
]()
{
const
auto
gamma_
window
=
[
&
]()
{
const
auto
gamma_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
GammaDataType
*>
(
kargs
.
p_gamma
),
static_cast
<
const
GammaDataType
*>
(
kargs
.
p_gamma
),
make_tuple
(
kargs
.
N
),
make_tuple
(
kargs
.
n
),
make_tuple
(
1
),
make_tuple
(
1
),
number
<
kNPerThread
>
{},
number
<
Vector_N
>
{},
number
<
1
>
{});
number
<
1
>
{});
return
pad_tensor_view
(
const
auto
tmp2_
=
gamma_dram_naive
,
make_tuple
(
number
<
kNPerBlock
>
{}),
sequence
<
kPadN
>
{});
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_N
>
{}),
sequence
<
false
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_N
>
{}),
{
0
});
}();
}();
const
auto
beta_
n
=
[
&
]()
{
const
auto
beta_
window
=
[
&
]()
{
const
auto
gamma_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
BetaDataType
*>
(
kargs
.
p_beta
),
static_cast
<
const
BetaDataType
*>
(
kargs
.
p_beta
),
make_tuple
(
kargs
.
N
),
make_tuple
(
kargs
.
n
),
make_tuple
(
1
),
make_tuple
(
1
),
number
<
kNPerThread
>
{},
number
<
Vector_N
>
{},
number
<
1
>
{});
number
<
1
>
{});
return
pad_tensor_view
(
const
auto
tmp2_
=
gamma_dram_naive
,
make_tuple
(
number
<
kNPerBlock
>
{}),
sequence
<
kPadN
>
{});
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_N
>
{}),
sequence
<
false
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
0
});
}();
}();
const
auto
iM
=
get_block_id
()
*
kMPerBlock
;
auto
y_window
=
[
&
]()
{
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
constexpr
auto
xDstr
=
MakeXBlockTileDistribution
();
auto
x_block_window
=
make_tile_window
(
x_m_n
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kNPerBlock
>
{}),
{
iM
,
0
},
xDstr
);
const
auto
y_m_n
=
[
&
]()
{
const
auto
y_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
YDataType
*>
(
kargs
.
p_y
),
static_cast
<
YDataType
*>
(
kargs
.
p_y
),
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
N
,
1
),
make_tuple
(
kargs
.
stride
,
1
),
number
<
kNPerThread
>
{},
number
<
Vector_N
>
{},
number
<
1
>
{});
number
<
1
>
{});
return
pad_tensor_view
(
y_dram_naive
,
auto
tmp2_
=
pad_tensor_view
(
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kNPerBlock
>
{}),
tmp_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
sequence
<
kPadM
,
kPadN
>
{});
sequence
<
kPadM
,
kPadN
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
}();
}();
auto
y_block_window
=
make_tile_window
(
auto
mean_window
=
[
&
]()
{
y_m_n
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kNPerBlock
>
{}),
{
iM
,
0
});
constexpr
auto
gammaDstr
=
MakeGammaBetaBlockTileDistribution
();
constexpr
auto
betaDstr
=
gammaDstr
;
auto
gamma_block_window
=
make_tile_window
(
gamma_n
,
make_tuple
(
number
<
kNPerBlock
>
{}),
{
0
},
gammaDstr
);
auto
beta_block_window
=
make_tile_window
(
beta_n
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kNPerBlock
>
{}),
{
0
},
betaDstr
);
auto
mean_block_window
=
[
&
]()
{
if
constexpr
(
kSaveMean
)
if
constexpr
(
kSaveMean
)
{
{
const
auto
mean_m
=
[
&
]()
{
const
auto
mean_m
=
[
&
]()
{
const
auto
mean_dram_naive
=
const
auto
mean_dram_naive
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
static_cast
<
MeanDataType
*>
(
kargs
.
p_mean
),
static_cast
<
MeanDataType
*>
(
kargs
.
p_mean
),
make_tuple
(
kargs
.
M
),
make_tuple
(
kargs
.
m
),
number
<
1
>
{});
number
<
1
>
{});
return
pad_tensor_view
(
return
pad_tensor_view
(
mean_dram_naive
,
make_tuple
(
number
<
kMPer
Block
>
{}),
sequence
<
kPadM
>
{});
mean_dram_naive
,
make_tuple
(
number
<
Block
_M
>
{}),
sequence
<
kPadM
>
{});
}();
}();
return
make_tile_window
(
mean_m
,
make_tuple
(
number
<
Block_M
>
{}),
{
iM
});
return
make_tile_window
(
mean_m
,
make_tuple
(
number
<
kMPerBlock
>
{}),
{
iM
});
}
}
else
else
return
make_null_tile_window
(
make_tuple
(
number
<
kMPer
Block
>
{}));
return
make_null_tile_window
(
make_tuple
(
number
<
Block
_M
>
{}));
}();
}();
auto
inv_std_
block_
window
=
[
&
]()
{
auto
inv_std_window
=
[
&
]()
{
if
constexpr
(
kSaveInvStd
)
if
constexpr
(
kSaveInvStd
)
{
{
const
auto
inv_std_m
=
[
&
]()
{
const
auto
inv_std_m
=
[
&
]()
{
const
auto
inv_std_dram_naive
=
const
auto
inv_std_dram_naive
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
static_cast
<
InvStdDataType
*>
(
kargs
.
p_invStd
),
static_cast
<
InvStdDataType
*>
(
kargs
.
p_invStd
),
make_tuple
(
kargs
.
M
),
make_tuple
(
kargs
.
m
),
number
<
1
>
{});
number
<
1
>
{});
return
pad_tensor_view
(
return
pad_tensor_view
(
inv_std_dram_naive
,
make_tuple
(
number
<
kMPer
Block
>
{}),
sequence
<
kPadM
>
{});
inv_std_dram_naive
,
make_tuple
(
number
<
Block
_M
>
{}),
sequence
<
kPadM
>
{});
}();
}();
return
make_tile_window
(
inv_std_m
,
make_tuple
(
number
<
Block_M
>
{}),
{
iM
});
return
make_tile_window
(
inv_std_m
,
make_tuple
(
number
<
kMPerBlock
>
{}),
{
iM
});
}
}
else
else
return
make_null_tile_window
(
make_tuple
(
number
<
kMPer
Block
>
{}));
return
make_null_tile_window
(
make_tuple
(
number
<
Block
_M
>
{}));
}();
}();
if
(
kargs
.
N
<=
kNPerBlock
)
__shared__
char
smem
[
GetSmemSize
()];
OnePassLayernorm2dFwd
(
x_block_window
,
gamma_block_window
,
Pipeline
{}(
x_window
,
beta_block_window
,
gamma_window
,
y_block_window
,
beta_window
,
mean_block_window
,
y_window
,
inv_std_block_window
,
mean_window
,
static_cast
<
const
ComputeDataType
>
(
kargs
.
epsilon
),
inv_std_window
,
kargs
.
N
);
static_cast
<
const
ComputeDataType
>
(
kargs
.
epsilon
),
else
kargs
.
n
,
TwoPassLayernorm2dFwd
(
x_block_window
,
smem
);
gamma_block_window
,
beta_block_window
,
y_block_window
,
mean_block_window
,
inv_std_block_window
,
static_cast
<
const
ComputeDataType
>
(
kargs
.
epsilon
),
kargs
.
N
);
}
}
};
};
...
...
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp
0 → 100644
View file @
f221c2b0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
/*
// clang-format off
4-level descriptor: BlockTile-> WarpPerBlock-> WarpTile-> Vector
Block_N (Warp_N * WarpPerBlock_N * Repeat_N )
+<----------------------< Repeat_N(2)>--------------------->+
| |
+<-- <WarpPerBlock_N(2)> -->+
Warp_N
+--------------+--------------+--------------+--------------+----+----------------+
Warp_M | wrap_0 | wrap_1 | | ^ ^
+--------------+--------------+ | <WarpPerBlock_M(2)> |
| wrap_2 | wrap_3 | | v
+--------------+--------------+--------------+--------------+----+ Block_M
| | |
+ + |
| | | v
+--------------+--------------+--------------+--------------+ +
each Warp-tile (e.g 16 thrd per row)
Vector_N (contiguous pixels each thrd holds along N, or vector size)
+-----------+-----------+-----------+-----------+-----------+
| thrd_0 | thrd_1 | thrd_2 | thrd_3 | ... Vector_M
+-----------+-----------+-----------+-----------+-----------+
| thrd_16 | thrd_17 | thrd_18 | thrd_19 | ...
+-----------+-----------+-----------+-----------+-----------+
// clang-format on
*/
template
<
typename
BlockTile_
,
// block size, seq<M, N>
typename
WarpPerBlock_
,
// num warps along seq<M, N>
typename
WarpTile_
,
// warp size, seq<M, N>
typename
Vector_
,
// contiguous pixels(vector size) along seq<M, N>
index_t
BlockSize_
=
warpSize
*
reduce_on_sequence
(
WarpPerBlock_
{}
,
multiplies
{}
,
number
<
1
>{})
>
struct
Layernorm2dShape
{
// block size
static
constexpr
index_t
Block_M
=
BlockTile_
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Block_N
=
BlockTile_
::
at
(
number
<
1
>
{});
// num warps along seq<M, N>, within each block
static
constexpr
index_t
WarpPerBlock_M
=
WarpPerBlock_
::
at
(
number
<
0
>
{});
static
constexpr
index_t
WarpPerBlock_N
=
WarpPerBlock_
::
at
(
number
<
1
>
{});
// warp size
static
constexpr
index_t
Warp_M
=
WarpTile_
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Warp_N
=
WarpTile_
::
at
(
number
<
1
>
{});
static_assert
(
Block_M
%
(
WarpPerBlock_M
*
Warp_M
)
==
0
);
static_assert
(
Block_N
%
(
WarpPerBlock_N
*
Warp_N
)
==
0
);
// repeat of each thread along seq<M, N>
static
constexpr
index_t
Repeat_M
=
Block_M
/
(
WarpPerBlock_M
*
Warp_M
);
static
constexpr
index_t
Repeat_N
=
Block_N
/
(
WarpPerBlock_N
*
Warp_N
);
// vector size along seq<M, N>
static
constexpr
index_t
Vector_M
=
Vector_
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Vector_N
=
Vector_
::
at
(
number
<
1
>
{});
static_assert
(
Warp_M
%
Vector_M
==
0
);
static_assert
(
Warp_N
%
Vector_N
==
0
);
// num of threads along seq<M, N>, within each warp
static
constexpr
index_t
ThreadPerWarp_M
=
Warp_M
/
Vector_M
;
static
constexpr
index_t
ThreadPerWarp_N
=
Warp_N
/
Vector_N
;
static
constexpr
index_t
BlockSize
=
BlockSize_
;
};
}
// namespace ck_tile
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp
0 → 100644
View file @
f221c2b0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/welford/block/block_welford_problem.hpp"
#include "ck_tile/ops/welford/block/block_welford.hpp"
namespace
ck_tile
{
struct
Layernorm2dFwdPipelineDefaultPolicy
{
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
MakeXBlockTileDistribution
()
{
using
S
=
typename
Problem
::
BlockShape
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
S
::
Repeat_M
,
S
::
WarpPerBlock_M
,
S
::
ThreadPerWarp_M
,
S
::
Vector_M
>
,
sequence
<
S
::
Repeat_N
,
S
::
WarpPerBlock_N
,
S
::
ThreadPerWarp_N
,
S
::
Vector_N
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>
,
sequence
<
2
,
2
>>
,
sequence
<
1
,
1
,
2
,
2
>
,
sequence
<
0
,
3
,
0
,
3
>>
{});
}
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
MakeGammaBetaBlockTileDistribution
()
{
using
S
=
typename
Problem
::
BlockShape
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
S
::
WarpPerBlock_M
,
S
::
ThreadPerWarp_M
>
,
tuple
<
sequence
<
S
::
Repeat_N
,
S
::
WarpPerBlock_N
,
S
::
ThreadPerWarp_N
,
S
::
Vector_N
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
1
,
2
>>
,
sequence
<
1
,
1
>
,
sequence
<
0
,
3
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockWelford
()
{
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
XDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
return
BlockWelford
<
P_
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockWelfordSync
()
{
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
XDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
return
BlockWelfordSync
<
P_
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockWelfordCrossWarpSync
()
{
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
XDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
return
BlockWelfordCrossWarpSync
<
P_
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
if
constexpr
(
Problem
::
kNeedCrossWarpSync
)
{
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
XDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
using
block_welford
=
BlockWelford
<
P_
>
;
using
x_block_tile
=
decltype
(
make_static_distributed_tensor
<
typename
Problem
::
XDataType
>
(
MakeXBlockTileDistribution
<
Problem
>
()));
using
mean_var_block_tile
=
decltype
(
block_welford
::
template
MakeMeanVarBlockTile
<
x_block_tile
>());
return
GetBlockWelfordCrossWarpSync
<
Problem
>
()
.
template
GetSmemSize
<
mean_var_block_tile
>();
}
else
{
return
1
;
// zero size arrays are an extension
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
0 → 100644
View file @
f221c2b0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>
namespace
ck_tile
{
template
<
typename
Problem_
,
typename
Policy_
=
Layernorm2dFwdPipelineDefaultPolicy
>
struct
Layernorm2dFwdPipelineOnePass
{
using
Problem
=
ck_tile
::
remove_cvref_t
<
Problem_
>
;
using
Policy
=
ck_tile
::
remove_cvref_t
<
Policy_
>
;
using
XDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
GammaDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
GammaDataType
>
;
using
BetaDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
BetaDataType
>
;
using
ComputeDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
YDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
MeanDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
MeanDataType
>
;
using
InvStdDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
InvStdDataType
>
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kHasBeta
=
!
std
::
is_same_v
<
BetaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kSaveMean
=
Problem
::
kSaveMeanInvStd
;
static
constexpr
bool
kSaveInvStd
=
Problem
::
kSaveMeanInvStd
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockLayernorm2dFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
const
char
*
name
=
[]()
{
if
constexpr
(
kNeedCrossWarpSync
)
return
"bpr"
;
// block per row
else
return
"wpr"
;
// warp per row
}();
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
typename
XWindow
,
typename
GammaWindow
,
typename
BetaWindow
,
typename
YWindow
,
typename
MeanWindow
,
typename
InvStdWindow
>
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
const
GammaWindow
&
gamma_window_
,
const
BetaWindow
&
beta_window_
,
YWindow
&
y_window
,
MeanWindow
&
mean_window
,
InvStdWindow
&
inv_std_window
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
row_size
,
void
*
smem
)
const
{
const
auto
x_window
=
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
const
auto
gamma_window
=
make_tile_window
(
gamma_window_
,
Policy
::
template
MakeGammaBetaBlockTileDistribution
<
Problem
>());
const
auto
beta_window
=
make_tile_window
(
beta_window_
,
Policy
::
template
MakeGammaBetaBlockTileDistribution
<
Problem
>());
const
auto
x
=
load_tile
(
x_window
);
int
cur_count
=
0
;
int
max_count
=
block_tile_welford_calculate_max_count
<
typename
Problem
::
BlockShape
>
(
row_size
);
auto
block_welford
=
Policy
::
template
GetBlockWelford
<
Problem
>();
auto
block_welford_sync
=
Policy
::
template
GetBlockWelfordSync
<
Problem
>();
auto
block_welford_cross_warp_sync
=
Policy
::
template
GetBlockWelfordCrossWarpSync
<
Problem
>();
// load gamma/beta (TODO: support no gamma/beta?)
const
auto
gamma
=
load_tile
(
gamma_window
);
const
auto
beta
=
load_tile
(
beta_window
);
// compute welford each-thread->cross-lane->cross-warp
auto
[
mean
,
var
]
=
block_welford
(
x
,
cur_count
,
max_count
);
block_welford_sync
(
mean
,
var
,
cur_count
);
block_welford_cross_warp_sync
(
mean
,
var
,
cur_count
,
smem
);
block_tile_welford_post_scale_var
(
var
,
cur_count
);
// compute inv-std
auto
inv_std
=
tile_elementwise_in
(
[
&
](
const
auto
&
v_
)
{
return
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
(
sqrt
(
v_
)
+
epsilon
);
},
var
);
if
constexpr
(
kSaveMean
)
store_tile
(
mean_window
,
cast_tile
<
MeanDataType
>
(
mean
));
if
constexpr
(
kSaveInvStd
)
store_tile
(
inv_std_window
,
cast_tile
<
InvStdDataType
>
(
inv_std
));
// layernorm computation
auto
y
=
make_static_distributed_tensor
<
YDataType
>
(
x
.
get_tile_distribution
());
sweep_tile
(
y
,
[
&
,
mean_
=
mean
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
const
auto
gamma_
=
type_convert
<
ComputeDataType
>
(
gamma
[
j_idx
]);
const
auto
beta_
=
type_convert
<
ComputeDataType
>
(
beta
[
j_idx
]);
const
auto
x_
=
type_convert
<
ComputeDataType
>
(
x
[
idx
]);
auto
y_
=
(
x_
-
mean_
[
i_idx
])
*
inv_std
[
i_idx
]
*
gamma_
+
beta_
;
y
(
idx
)
=
type_convert
<
YDataType
>
(
y_
);
});
store_tile
(
y_window
,
y
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/layernorm2d/pipeline/
block_
layernorm2d_fwd_problem.hpp
→
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_
pipeline_
problem.hpp
View file @
f221c2b0
...
@@ -15,20 +15,26 @@ template <typename XDataType_,
...
@@ -15,20 +15,26 @@ template <typename XDataType_,
typename
MeanDataType_
,
typename
MeanDataType_
,
typename
InvStdDataType_
,
typename
InvStdDataType_
,
typename
BlockShape_
,
typename
BlockShape_
,
bool
kPadM_
,
bool
kPadN_
,
bool
kPadN_
>
bool
kSaveMeanInvStd_
,
struct
BlockLayernorm2dFwdProblem
bool
kTwoPass_
>
struct
Layernorm2dFwdPipelineProblem
{
{
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
GammaDataType
=
remove_cvref_t
<
GammaDataType_
>
;
using
GammaDataType
=
remove_cvref_t
<
GammaDataType_
>
;
using
BetaDataType
=
remove_cvref_t
<
BetaDataType_
>
;
using
BetaDataType
=
remove_cvref_t
<
BetaDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
YDataType
=
remove_cvref_t
<
YDataType_
>
;
using
YDataType
=
remove_cvref_t
<
YDataType_
>
;
using
MeanDataType
=
remove_cvref_t
<
MeanDataType_
>
;
using
MeanDataType
=
remove_cvref_t
<
MeanDataType_
>
;
using
InvStdDataType
=
remove_cvref_t
<
InvStdDataType_
>
;
using
InvStdDataType
=
remove_cvref_t
<
InvStdDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
bool
kPadM
=
kPadM_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kNeedCrossLaneSync
=
BlockShape
::
ThreadPerWarp_N
>
1
;
static
constexpr
bool
kNeedCrossWarpSync
=
BlockShape
::
WarpPerBlock_N
>
1
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kSaveMeanInvStd
=
kSaveMeanInvStd_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
};
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
0 → 100644
View file @
f221c2b0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>
namespace
ck_tile
{
template
<
typename
Problem_
,
typename
Policy_
=
Layernorm2dFwdPipelineDefaultPolicy
>
struct
Layernorm2dFwdPipelineTwoPass
{
using
Problem
=
ck_tile
::
remove_cvref_t
<
Problem_
>
;
using
Policy
=
ck_tile
::
remove_cvref_t
<
Policy_
>
;
using
XDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
GammaDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
GammaDataType
>
;
using
BetaDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
BetaDataType
>
;
using
ComputeDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
YDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
MeanDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
MeanDataType
>
;
using
InvStdDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
InvStdDataType
>
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kHasBeta
=
!
std
::
is_same_v
<
BetaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kSaveMean
=
Problem
::
kSaveMeanInvStd
;
static
constexpr
bool
kSaveInvStd
=
Problem
::
kSaveMeanInvStd
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockLayernorm2dFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
const
char
*
name
=
[]()
{
if
constexpr
(
kNeedCrossWarpSync
)
return
"bpr"
;
// block per row
else
return
"wpr"
;
// warp per row
}();
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
typename
XWindow
,
typename
GammaWindow
,
typename
BetaWindow
,
typename
YWindow
,
typename
MeanWindow
,
typename
InvStdWindow
>
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
const
GammaWindow
&
gamma_window_
,
const
BetaWindow
&
beta_window_
,
YWindow
&
y_window
,
MeanWindow
&
mean_window
,
InvStdWindow
&
inv_std_window
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
row_size
,
void
*
smem
)
const
{
auto
x_window
=
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
gamma_window
=
make_tile_window
(
gamma_window_
,
Policy
::
template
MakeGammaBetaBlockTileDistribution
<
Problem
>());
auto
beta_window
=
make_tile_window
(
beta_window_
,
Policy
::
template
MakeGammaBetaBlockTileDistribution
<
Problem
>());
// Problem::BlockShape
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
index_t
num_n_tile_iteration
=
__builtin_amdgcn_readfirstlane
(
integer_divide_ceil
(
row_size
,
Block_N
));
// total number of count assume current iter have no pad(only last iter has pad)
constexpr
index_t
count_per_iter
=
Problem
::
BlockShape
::
Repeat_N
*
Problem
::
BlockShape
::
Vector_N
;
const
index_t
last_iter_n
=
row_size
-
(
num_n_tile_iteration
-
1
)
*
Block_N
;
int
cur_count
=
0
;
int
max_count
=
(
num_n_tile_iteration
-
1
)
*
count_per_iter
+
block_tile_welford_calculate_max_count
<
typename
Problem
::
BlockShape
>
(
last_iter_n
);
auto
block_welford
=
Policy
::
template
GetBlockWelford
<
Problem
>();
auto
block_welford_sync
=
Policy
::
template
GetBlockWelfordSync
<
Problem
>();
auto
block_welford_cross_warp_sync
=
Policy
::
template
GetBlockWelfordCrossWarpSync
<
Problem
>();
using
XTensorType
=
decltype
(
load_tile
(
x_window
));
auto
mean
=
block_welford
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
auto
var
=
block_welford
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
const
auto
x
=
load_tile
(
x_window
);
block_welford
(
x
,
mean
,
var
,
cur_count
,
max_count
);
move_tile_window
(
x_window
,
{
0
,
Block_N
});
}
block_welford_sync
(
mean
,
var
,
cur_count
);
block_welford_cross_warp_sync
(
mean
,
var
,
cur_count
,
smem
);
block_tile_welford_post_scale_var
(
var
,
cur_count
);
// compute inv-std
auto
inv_std
=
tile_elementwise_in
(
[
&
](
const
auto
&
v_
)
{
return
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
(
sqrt
(
v_
)
+
epsilon
);
},
var
);
if
constexpr
(
kSaveMean
)
store_tile
(
mean_window
,
cast_tile
<
MeanDataType
>
(
mean
));
if
constexpr
(
kSaveInvStd
)
store_tile
(
inv_std_window
,
cast_tile
<
InvStdDataType
>
(
inv_std
));
// reverse read x to reuse cache
ck_tile
::
index_t
stride_to_right_most_window
=
row_size
%
Block_N
==
0
?
row_size
-
Block_N
:
row_size
-
row_size
%
Block_N
;
// x_window.foo();
// gamma_window.foo();
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
gamma_window
,
{
stride_to_right_most_window
});
move_tile_window
(
beta_window
,
{
stride_to_right_most_window
});
move_tile_window
(
y_window
,
{
0
,
stride_to_right_most_window
});
// layernorm computation
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
const
auto
x
=
load_tile
(
x_window
);
// load gamma/beta (TODO: support no gamma/beta?)
const
auto
gamma
=
load_tile
(
gamma_window
);
const
auto
beta
=
load_tile
(
beta_window
);
auto
y
=
make_static_distributed_tensor
<
YDataType
>
(
x
.
get_tile_distribution
());
sweep_tile
(
y
,
[
&
,
mean_
=
mean
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
const
auto
gamma_
=
type_convert
<
ComputeDataType
>
(
gamma
[
j_idx
]);
const
auto
beta_
=
type_convert
<
ComputeDataType
>
(
beta
[
j_idx
]);
const
auto
x_
=
type_convert
<
ComputeDataType
>
(
x
[
idx
]);
auto
y_
=
(
x_
-
mean_
[
i_idx
])
*
inv_std
[
i_idx
]
*
gamma_
+
beta_
;
y
(
idx
)
=
type_convert
<
YDataType
>
(
y_
);
});
store_tile
(
y_window
,
y
);
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
gamma_window
,
{
-
Block_N
});
move_tile_window
(
beta_window
,
{
-
Block_N
});
move_tile_window
(
y_window
,
{
0
,
-
Block_N
});
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp
deleted
100644 → 0
View file @
140d2fa6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
ThreadTile
,
// Sequence<...
typename
WarpTile
,
// Sequence<...
typename
BlockTile
>
// Sequence<...
struct
TileLayernorm2dShape
{
static
constexpr
index_t
kMPerThread
=
ThreadTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kNPerThread
=
ThreadTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kMPerWarp
=
WarpTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kNPerWarp
=
WarpTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kMThreadPerWarp
=
kMPerWarp
/
kMPerThread
;
static
constexpr
index_t
kNThreadPerWarp
=
kNPerWarp
/
kNPerThread
;
static
constexpr
index_t
kMPerBlock
=
BlockTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kNPerBlock
=
BlockTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kMWarpPerBlock
=
kMPerBlock
/
kMPerWarp
;
static
constexpr
index_t
kNWarpPerBlock
=
kNPerBlock
/
kNPerWarp
;
// TODO - kNNumWarps can only be 1 if we don't support cross warp welford
static_assert
(
kNWarpPerBlock
==
1
);
static
constexpr
index_t
kBlockSize
=
warpSize
*
kMWarpPerBlock
*
kNWarpPerBlock
;
};
}
// namespace ck_tile
include/ck_tile/ops/reduce/block/block_reduce.hpp
View file @
f221c2b0
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck_tile/ops/welford.hpp
View file @
f221c2b0
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#pragma once
#pragma once
#include "ck_tile/ops/welford/block/block_welford.hpp"
#include "ck_tile/ops/welford/block/block_welford_problem.hpp"
#include "ck_tile/ops/welford/thread/thread_welford.hpp"
#include "ck_tile/ops/welford/thread/thread_welford.hpp"
#include "ck_tile/ops/welford/warp/warp_welford.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/welford/block/block_welford.hpp
0 → 100644
View file @
f221c2b0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/welford/thread/thread_welford.hpp"
namespace
ck_tile
{
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
BlockWelford
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
XDataType
=
typename
Problem
::
XDataType
;
using
ComputeDataType
=
typename
Problem
::
ComputeDataType
;
CK_TILE_DEVICE
constexpr
BlockWelford
()
{}
// [CAUSION] - max_count_ is to deal with the padding problem
// max_count_ is depend on caller, eg: naive and splitN welford will have different
// calculation of max_count_
// -> use block_welford_calculate_max_count to compute
template
<
typename
XDistributedTensor_
,
typename
MeanDistributedTensor_
,
typename
VarDistributedTensor_
>
CK_TILE_DEVICE
void
operator
()(
const
XDistributedTensor_
&
x_tensor
,
MeanDistributedTensor_
&
mean_tensor
,
VarDistributedTensor_
&
var_tensor
,
int
&
cur_count_
,
// -> prefer init as zero
const
int
&
max_count_
)
{
constexpr
auto
I0
=
number
<
0
>
{};
constexpr
auto
I1
=
number
<
1
>
{};
constexpr
auto
spans
=
XDistributedTensor_
::
get_distributed_spans
();
sweep_tile_span
(
spans
[
I1
],
[
&
](
auto
dstr_idx_i1
)
{
if
(
cur_count_
<
max_count_
)
{
++
cur_count_
;
sweep_tile_span
(
spans
[
I0
],
[
&
](
auto
dstr_idx_i0
)
{
constexpr
auto
in_dstr_idx
=
make_tuple
(
dstr_idx_i0
,
dstr_idx_i1
);
constexpr
auto
out_dstr_idx
=
make_tuple
(
dstr_idx_i0
);
auto
x
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_tensor
[
in_dstr_idx
]);
welford_update
(
mean_tensor
(
out_dstr_idx
),
var_tensor
(
out_dstr_idx
),
x
,
cur_count_
);
});
}
});
}
template
<
typename
XDistributedTensor_
>
CK_TILE_DEVICE
static
auto
MakeMeanVarBlockTile
()
{
static_assert
(
std
::
is_same_v
<
XDataType
,
typename
XDistributedTensor_
::
DataType
>
,
"wrong!"
);
constexpr
auto
reduce_dims
=
sequence
<
1
>
{};
constexpr
auto
dstr
=
make_static_tile_distribution
(
detail
::
make_reduce_tile_distribution_encoding
(
XDistributedTensor_
::
get_tile_distribution
()
.
get_static_tile_distribution_encoding
(),
reduce_dims
));
auto
tensor
=
make_static_distributed_tensor
<
ComputeDataType
>
(
dstr
);
return
tensor
;
}
template
<
typename
XDistributedTensor_
>
CK_TILE_DEVICE
auto
operator
()(
const
XDistributedTensor_
&
x_tensor
,
int
&
cur_count_
,
const
int
&
max_count_
)
{
auto
mean_tensor
=
MakeMeanVarBlockTile
<
XDistributedTensor_
>
();
auto
var_tensor
=
MakeMeanVarBlockTile
<
XDistributedTensor_
>
();
clear_tile
(
mean_tensor
);
clear_tile
(
var_tensor
);
(
*
this
)(
x_tensor
,
mean_tensor
,
var_tensor
,
cur_count_
,
max_count_
);
return
ck_tile
::
make_tuple
(
mean_tensor
,
var_tensor
);
}
};
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
BlockWelfordSync
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
template
<
typename
MeanDistributedTensor_
,
typename
VarDistributedTensor_
>
CK_TILE_DEVICE
void
operator
()(
MeanDistributedTensor_
&
mean_tensor
,
VarDistributedTensor_
&
var_tensor
,
int
&
count
)
{
using
Dstr
=
typename
MeanDistributedTensor_
::
StaticTileDistribution
;
using
DstrEncode
=
typename
Dstr
::
DstrEncode
;
using
DstrEncodeDetail
=
typename
DstrEncode
::
detail
;
static_assert
(
std
::
is_same_v
<
Dstr
,
typename
VarDistributedTensor_
::
StaticTileDistribution
>
,
"wrong!"
);
constexpr
index_t
NDimP
=
Dstr
::
get_num_of_dimension_p
();
constexpr
index_t
NDimR
=
Dstr
::
get_num_of_dimension_r
();
constexpr
index_t
idim_p_lane
=
NDimP
-
1
;
// const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
// const auto rs_idx =
// mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
constexpr
index_t
thread_buf_size
=
MeanDistributedTensor_
::
get_thread_buffer_size
();
static_assert
(
thread_buf_size
==
VarDistributedTensor_
::
get_thread_buffer_size
());
const
int
original_count
=
count
;
// loop over thread data
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i
)
{
auto
v_local_mean
=
mean_tensor
.
get_thread_buffer
()[
i
];
auto
v_local_var
=
var_tensor
.
get_thread_buffer
()[
i
];
auto
v_local_count
=
original_count
;
// cross-lane reduce for replication
// only reduce on R dimension correspond to lane
// (lane id maps to this R dimension)
static_for
<
0
,
NDimR
,
1
>
{}([
&
](
auto
idim_r
)
{
// FIXME: nasty to use does_p_own_r_
if
constexpr
(
DstrEncodeDetail
::
does_p_own_r_
[
idim_p_lane
][
idim_r
])
{
constexpr
index_t
r_length
=
DstrEncode
::
rs_lengths_
[
idim_r
];
constexpr
index_t
lid_over_rid_derivative
=
DstrEncodeDetail
::
ps_over_rs_derivative_
[
idim_p_lane
][
idim_r
];
static_assert
(
is_power_of_two_integer
(
r_length
),
"wrong! only support power of 2 reduction"
);
constexpr
index_t
nstage
=
integer_log2_floor
(
r_length
);
// reduction sweep forward
static_for
<
0
,
nstage
,
1
>
{}([
&
](
auto
istage
)
{
// xor
index_t
src_lane
=
(
__lane_id
())
^
(
number
<
lid_over_rid_derivative
<<
istage
.
value
>
{}.
value
);
// pull data from remote lane
const
auto
v_remote_mean
=
warp_shuffle
(
v_local_mean
,
src_lane
);
const
auto
v_remote_var
=
warp_shuffle
(
v_local_var
,
src_lane
);
const
auto
v_remote_count
=
warp_shuffle
(
v_local_count
,
src_lane
);
// welford merge
welford_merge
(
v_local_mean
,
v_local_var
,
v_local_count
,
v_remote_mean
,
v_remote_var
,
v_remote_count
);
});
}
});
mean_tensor
.
get_thread_buffer
()(
i
)
=
v_local_mean
;
var_tensor
.
get_thread_buffer
()(
i
)
=
v_local_var
;
count
=
v_local_count
;
});
}
};
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
BlockWelfordCrossWarpSync
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
BlockShape
=
typename
Problem
::
BlockShape
;
template
<
typename
MeanDistributedTensor_
>
CK_TILE_DEVICE
static
constexpr
index_t
GetReduceWarps
()
{
constexpr
index_t
num_reduce_warps
=
[
&
]()
{
using
Dstr
=
typename
MeanDistributedTensor_
::
StaticTileDistribution
;
using
DstrEncode
=
typename
Dstr
::
DstrEncode
;
using
DstrEncodeDetail
=
typename
DstrEncode
::
detail
;
constexpr
index_t
NDimR
=
Dstr
::
get_num_of_dimension_r
();
constexpr
index_t
idim_p_warp
=
0
;
index_t
len_
=
1
;
static_for
<
0
,
NDimR
,
1
>
{}([
&
](
auto
idim_r
)
{
if
constexpr
(
DstrEncodeDetail
::
does_p_own_r_
[
idim_p_warp
][
idim_r
])
{
constexpr
index_t
r_length
=
DstrEncode
::
rs_lengths_
[
idim_r
];
len_
*=
r_length
;
}
});
return
len_
;
}();
return
num_reduce_warps
;
}
// return in byte
template
<
typename
MeanDistributedTensor_
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
// constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();
// data need to exchange is very small, we just pack mean+var+count -> 4dword
constexpr
index_t
thread_buf_size
=
MeanDistributedTensor_
::
get_thread_buffer_size
();
// we need to store all data from every wave into smem
// e.g. 2x2 reduce along N
// -------------> reduce N
// | w0 | w1 | ___> | w01 |
// | w2 | w3 | | w23 |
//
// -> store data from every wave into LDS
//
//
// -------------> reduce N
// | w0 | w1 | w2 | w3 | -----> | w0123 |
//
// -> also store data from every wave into LDS
constexpr
index_t
num_warps
=
BlockShape
::
BlockSize
/
warpSize
;
return
num_warps
*
4
*
thread_buf_size
*
sizeof
(
float
);
}
template
<
typename
MeanDistributedTensor_
,
typename
VarDistributedTensor_
>
CK_TILE_DEVICE
void
operator
()(
MeanDistributedTensor_
&
mean_tensor
,
VarDistributedTensor_
&
var_tensor
,
int
&
count
,
void
*
smem
)
{
using
DataType
=
typename
MeanDistributedTensor_
::
DataType
;
using
Dstr
=
typename
MeanDistributedTensor_
::
StaticTileDistribution
;
// using DstrEncode = typename Dstr::DstrEncode;
// using DstrEncodeDetail = typename DstrEncode::detail;
static_assert
(
std
::
is_same_v
<
Dstr
,
typename
VarDistributedTensor_
::
StaticTileDistribution
>
,
"wrong!"
);
constexpr
index_t
thread_buf_size
=
MeanDistributedTensor_
::
get_thread_buffer_size
();
static_assert
(
thread_buf_size
==
VarDistributedTensor_
::
get_thread_buffer_size
());
// Note: we always pack everything into fp32x4
fp32x4_t
*
smem_ptr
=
reinterpret_cast
<
fp32x4_t
*>
(
smem
);
const
index_t
lane_id
=
get_lane_id
();
const
index_t
warp_id
=
get_warp_id
();
constexpr
auto
num_reduce_warps
=
GetReduceWarps
<
MeanDistributedTensor_
>
();
constexpr
index_t
num_warps
=
BlockShape
::
BlockSize
/
warpSize
;
const
index_t
smem_offset
=
warp_id
;
// skip if nonthing to do
if
constexpr
(
num_reduce_warps
==
1
)
return
;
// store into smem only for lane-0 within one warp
if
(
lane_id
==
0
)
{
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i
)
{
fp32x4_t
local_scratch_
;
local_scratch_
[
0
]
=
bit_cast
<
float
>
(
mean_tensor
.
get_thread_buffer
()[
i
]);
local_scratch_
[
1
]
=
bit_cast
<
float
>
(
var_tensor
.
get_thread_buffer
()[
i
]);
local_scratch_
[
2
]
=
bit_cast
<
float
>
(
count
);
smem_ptr
[
smem_offset
+
i
*
num_warps
]
=
local_scratch_
;
});
}
block_sync_lds
();
// load from smem. here we let everythread to do compute :)
index_t
local_warp_id
=
warp_id
/
num_reduce_warps
;
index_t
local_smem_os
=
local_warp_id
*
num_reduce_warps
;
fp32x4_t
all_scratch
[
thread_buf_size
*
num_reduce_warps
];
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i_0
)
{
static_for
<
0
,
num_reduce_warps
,
1
>
{}([
&
](
auto
i_1
)
{
all_scratch
[
i_0
*
num_warps
+
i_1
]
=
smem_ptr
[
i_0
*
num_reduce_warps
+
local_smem_os
+
i_1
];
});
});
block_sync_lds
();
// TODO: we don't need sync here
// const int original_count = count;
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i_0
)
{
// TODO: use descriptor for this
auto
v_local
=
all_scratch
[
i_0
*
num_warps
];
auto
v_local_mean
=
bit_cast
<
DataType
>
(
v_local
[
0
]);
auto
v_local_var
=
bit_cast
<
DataType
>
(
v_local
[
1
]);
auto
v_local_count
=
bit_cast
<
int
>
(
v_local
[
2
]);
// further reduce mean/var
static_for
<
0
,
num_reduce_warps
-
1
,
1
>
{}([
&
](
auto
i_1_n1
)
{
constexpr
auto
i_1
=
number
<
i_1_n1
+
1
>
{};
const
fp32x4_t
v_remote
=
all_scratch
[
i_0
*
num_warps
+
i_1
];
const
auto
v_remote_mean
=
bit_cast
<
DataType
>
(
v_remote
[
0
]);
const
auto
v_remote_var
=
bit_cast
<
DataType
>
(
v_remote
[
1
]);
const
auto
v_remote_count
=
bit_cast
<
int
>
(
v_remote
[
2
]);
welford_merge
(
v_local_mean
,
v_local_var
,
v_local_count
,
v_remote_mean
,
v_remote_var
,
v_remote_count
);
});
mean_tensor
.
get_thread_buffer
()(
i_0
)
=
v_local_mean
;
var_tensor
.
get_thread_buffer
()(
i_0
)
=
v_local_var
;
count
=
v_local_count
;
});
}
};
// compute the max count for a last dim reduce
// everything may have vector/repeat, so the max count could be uneven
// TODO: specify which dim to compute and proper set the problem
// TODO: BlockShape we reuse layernorm_fwd_shape :)
template
<
typename
BlockShape
>
CK_TILE_DEVICE
constexpr
index_t
block_tile_welford_calculate_max_count
(
int
row_size
)
{
#if 0
using S = BlockShape;
index_t LastloopN = row_size % S::Block_N == 0 ? S::Block_N : row_size % S::Block_N;
constexpr index_t NThread = S::WarpPerBlock_N * S::ThreadPerWarp_N;
index_t iNLane = get_thread_id() % NThread;
index_t iN0 = LastloopN / (S::Vector_N * S::ThreadPerWarp_N);
index_t iN1 = (LastloopN % (S::Vector_N * S::ThreadPerWarp_N)) / S::Vector_N;
index_t N2 = (LastloopN % (S::Vector_N * S::ThreadPerWarp_N)) % S::Vector_N;
index_t iN3 = iNLane < iN1 ? S::Vector_N : iNLane == iN1 ? N2 : 0;
return iN0 * S::Vector_N + iN3;
#endif
using
S_
=
BlockShape
;
constexpr
index_t
ThreadsPerBlock_N
=
S_
::
WarpPerBlock_N
*
S_
::
ThreadPerWarp_N
;
// TODO: we always check vector size, need be evenly devidable by vector-n
const
index_t
element_per_row
=
row_size
/
S_
::
Vector_N
;
index_t
lane_id_n
=
get_thread_id
()
%
ThreadsPerBlock_N
;
index_t
cnt
=
0
;
// TODO: Repeat_N can not be too long, otherwise this is not good
static_for
<
0
,
S_
::
Repeat_N
,
1
>
{}([
&
](
auto
)
{
index_t
_a
=
lane_id_n
<
element_per_row
?
1
:
0
;
cnt
+=
_a
;
lane_id_n
+=
ThreadsPerBlock_N
;
});
return
cnt
*
S_
::
Vector_N
;
}
// Note: this function must be called after all the computation
template
<
typename
VarDistributedTensor_
>
CK_TILE_DEVICE
constexpr
void
block_tile_welford_post_scale_var
(
VarDistributedTensor_
&
var_tensor
,
int
count
)
{
using
DataType
=
typename
VarDistributedTensor_
::
DataType
;
tile_elementwise_inout
([
&
count
](
auto
&
x
)
{
x
=
x
/
type_convert
<
DataType
>
(
count
);
},
var_tensor
);
}
}
// namespace ck_tile
include/ck_tile/ops/welford/block/block_welford_problem.hpp
0 → 100644
View file @
f221c2b0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
XDataType_
,
typename
ComputeDataType_
,
typename
BlockShape_
>
struct
BlockWelfordProblem
{
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
};
}
// namespace ck_tile
include/ck_tile/ops/welford/thread/thread_welford.hpp
View file @
f221c2b0
...
@@ -7,95 +7,30 @@
...
@@ -7,95 +7,30 @@
namespace
ck_tile
{
namespace
ck_tile
{
template
<
typename
ComputeDataType_
,
typename
XDataType_
>
template
<
typename
T
>
struct
ThreadWelford
CK_TILE_DEVICE
void
welford_update
(
T
&
mean
,
T
&
var
,
T
x
,
int
count
)
{
{
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
// TODO: check nan? maybe no
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
T
delta
=
x
-
mean
;
mean
+=
delta
/
count
;
template
<
typename
T
>
T
delta2
=
x
-
mean
;
CK_TILE_DEVICE
void
Update
(
T
&
mean
,
T
&
var
,
T
x
)
var
+=
delta
*
delta2
;
{
}
if
(
ck_tile
::
isnan
(
x
))
{
template
<
typename
T
>
mean
=
x
;
CK_TILE_DEVICE
static
void
var
=
x
;
welford_merge
(
T
&
mean_a
,
T
&
var_a
,
int
&
count_a
,
T
mean_b
,
T
var_b
,
int
count_b
)
}
{
else
int
count
=
count_a
+
count_b
;
{
T
count_
=
type_convert
<
T
>
(
count
);
T
delta
=
x
-
mean
;
T
count_a_
=
type_convert
<
T
>
(
count_a
);
mean
+=
delta
/
cur_count_
;
T
count_b_
=
type_convert
<
T
>
(
count_b
);
T
delta2
=
x
-
mean
;
T
count_b_over_count
=
count
==
0
?
type_convert
<
T
>
(
0
)
:
count_b_
/
count_
;
var
+=
delta
*
delta2
;
}
T
delta
=
mean_b
-
mean_a
;
}
mean_a
+=
delta
*
count_b_over_count
;
var_a
+=
var_b
+
delta
*
delta
*
count_a_
*
count_b_over_count
;
// [CAUSION] - max_count_ is to deal with the padding problem
count_a
=
count
;
// max_count_ is depend on caller, eg: naive and splitN welford will have different
}
// calculation of max_count_
CK_TILE_DEVICE
constexpr
ThreadWelford
(
int
max_count
)
:
cur_count_
(
0
),
max_count_
(
max_count
)
{}
template
<
typename
XDistributedTensor_
,
typename
MeanDistributedTensor_
,
typename
VarDistributedTensor_
>
CK_TILE_DEVICE
void
operator
()(
const
XDistributedTensor_
&
x_tensor
,
MeanDistributedTensor_
&
mean_tensor
,
VarDistributedTensor_
&
var_tensor
)
{
constexpr
auto
I0
=
number
<
0
>
{};
constexpr
auto
I1
=
number
<
1
>
{};
constexpr
auto
spans
=
XDistributedTensor_
::
get_distributed_spans
();
sweep_tile_span
(
spans
[
I1
],
[
&
](
auto
dstr_idx_i1
)
{
if
(
cur_count_
<
max_count_
)
{
++
cur_count_
;
sweep_tile_span
(
spans
[
I0
],
[
&
](
auto
dstr_idx_i0
)
{
constexpr
auto
in_dstr_idx
=
make_tuple
(
dstr_idx_i0
,
dstr_idx_i1
);
constexpr
auto
out_dstr_idx
=
make_tuple
(
dstr_idx_i0
);
auto
x
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_tensor
[
in_dstr_idx
]);
Update
(
mean_tensor
(
out_dstr_idx
),
var_tensor
(
out_dstr_idx
),
x
);
});
}
});
}
template
<
typename
XDistributedTensor_
>
CK_TILE_DEVICE
static
auto
MakeInitialMeanVarDistributedTensor
()
{
static_assert
(
std
::
is_same_v
<
XDataType
,
typename
XDistributedTensor_
::
DataType
>
,
"wrong!"
);
constexpr
auto
reduce_dims
=
sequence
<
1
>
{};
constexpr
auto
dstr
=
make_static_tile_distribution
(
detail
::
make_reduce_tile_distribution_encoding
(
XDistributedTensor_
::
get_tile_distribution
()
.
get_static_tile_distribution_encoding
(),
reduce_dims
));
auto
tensor
=
make_static_distributed_tensor
<
ComputeDataType
>
(
dstr
);
clear_tile
(
tensor
);
return
tensor
;
}
template
<
typename
XDistributedTensor_
>
CK_TILE_DEVICE
auto
operator
()(
const
XDistributedTensor_
&
x_tensor
)
{
auto
mean_tensor
=
MakeInitialMeanVarDistributedTensor
<
XDistributedTensor_
>
();
auto
var_tensor
=
MakeInitialMeanVarDistributedTensor
<
XDistributedTensor_
>
();
(
*
this
)(
x_tensor
,
mean_tensor
,
var_tensor
);
return
ck_tile
::
make_tuple
(
mean_tensor
,
var_tensor
);
}
int
cur_count_
;
int
max_count_
;
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/welford/warp/warp_welford.hpp
deleted
100644 → 0
View file @
140d2fa6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
ComputeDataType_
,
bool
BroadcastLane
=
true
,
bool
GetActualVariance
=
true
>
struct
WarpMergeWelford
{
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
template
<
typename
T
>
CK_TILE_DEVICE
static
void
Merge
(
T
&
mean_a
,
T
&
var_a
,
int
&
count_a
,
T
mean_b
,
T
var_b
,
int
count_b
)
{
int
count
=
count_a
+
count_b
;
T
count_
=
type_convert
<
T
>
(
count
);
T
count_a_
=
type_convert
<
T
>
(
count_a
);
T
count_b_
=
type_convert
<
T
>
(
count_b
);
T
count_b_over_count
=
count
==
0
?
type_convert
<
T
>
(
0
)
:
count_b_
/
count_
;
T
delta
=
mean_b
-
mean_a
;
mean_a
+=
delta
*
count_b_over_count
;
var_a
+=
var_b
+
delta
*
delta
*
count_a_
*
count_b_over_count
;
count_a
=
count
;
}
template
<
typename
MeanDistributedTensor_
,
typename
VarDistributedTensor_
>
CK_TILE_DEVICE
void
operator
()(
MeanDistributedTensor_
&
mean_tensor
,
VarDistributedTensor_
&
var_tensor
,
int
&
count
)
{
using
Dstr
=
typename
MeanDistributedTensor_
::
StaticTileDistribution
;
using
DstrEncode
=
typename
Dstr
::
DstrEncode
;
using
DstrEncodeDetail
=
typename
DstrEncode
::
detail
;
static_assert
(
std
::
is_same_v
<
Dstr
,
typename
VarDistributedTensor_
::
StaticTileDistribution
>
,
"wrong!"
);
constexpr
index_t
NDimP
=
Dstr
::
get_num_of_dimension_p
();
constexpr
index_t
NDimR
=
Dstr
::
get_num_of_dimension_r
();
constexpr
index_t
idim_p_lane
=
NDimP
-
1
;
const
auto
ps_idx
=
make_array
<
index_t
>
(
get_warp_id
(),
get_lane_id
());
const
auto
rs_idx
=
mean_tensor
.
get_tile_distribution
().
calculate_rs_index_from_ps_index
(
ps_idx
);
constexpr
index_t
thread_buf_size
=
MeanDistributedTensor_
::
get_thread_buffer_size
();
static_assert
(
thread_buf_size
==
VarDistributedTensor_
::
get_thread_buffer_size
());
const
int
original_count
=
count
;
// loop over thread data
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i
)
{
auto
v_local_mean
=
mean_tensor
.
get_thread_buffer
()[
i
];
auto
v_local_var
=
var_tensor
.
get_thread_buffer
()[
i
];
auto
v_local_count
=
original_count
;
// cross-lane reduce for replication
// only reduce on R dimension correspond to lane
// (lane id maps to this R dimension)
static_for
<
0
,
NDimR
,
1
>
{}([
&
](
auto
idim_r
)
{
// FIXME: nasty to use does_p_own_r_
if
constexpr
(
DstrEncodeDetail
::
does_p_own_r_
[
idim_p_lane
][
idim_r
])
{
constexpr
index_t
r_length
=
DstrEncode
::
rs_lengths_
[
idim_r
];
constexpr
index_t
lid_over_rid_derivative
=
DstrEncodeDetail
::
ps_over_rs_derivative_
[
idim_p_lane
][
idim_r
];
static_assert
(
is_power_of_two_integer
(
r_length
),
"wrong! only support power of 2 reduction"
);
constexpr
index_t
nstage
=
integer_log2_floor
(
r_length
);
// reduction sweep forward
static_for
<
0
,
nstage
,
1
>
{}([
&
](
auto
istage
)
{
constexpr
index_t
lid_delta
=
lid_over_rid_derivative
*
(
1
<<
(
nstage
-
istage
-
1
));
// pull data from remote lane
const
auto
v_remote_mean
=
warp_shuffle_down
(
v_local_mean
,
lid_delta
);
const
auto
v_remote_var
=
warp_shuffle_down
(
v_local_var
,
lid_delta
);
const
auto
v_remote_count
=
warp_shuffle_down
(
v_local_count
,
lid_delta
);
// welford merge
Merge
(
v_local_mean
,
v_local_var
,
v_local_count
,
v_remote_mean
,
v_remote_var
,
v_remote_count
);
});
}
});
// cross-lane broadcast for replication
// only broadcast on R dimension correspond to lane
// (lane id maps to this R dimension)
if
constexpr
(
BroadcastLane
)
{
static_for
<
0
,
NDimR
,
1
>
{}([
&
](
auto
idim_r
)
{
// FIXME: nasty to use does_p_own_r_
if
constexpr
(
DstrEncodeDetail
::
does_p_own_r_
[
idim_p_lane
][
idim_r
])
{
const
index_t
r_id
=
rs_idx
[
idim_r
];
constexpr
index_t
r_length
=
DstrEncode
::
rs_lengths_
[
idim_r
];
constexpr
index_t
lid_over_rid_derivative
=
DstrEncodeDetail
::
ps_over_rs_derivative_
[
NDimP
-
1
][
idim_r
];
static_assert
(
is_power_of_two_integer
(
r_length
),
"wrong! only support power of 2 reduction"
);
constexpr
index_t
nstage
=
integer_log2_floor
(
r_length
);
// broadcast sweep backward
static_for
<
0
,
nstage
,
1
>
{}([
&
](
auto
istage
)
{
// do I hold reduced data?
const
bool
do_i_hold_reduced_data
=
r_id
<
(
1
<<
istage
);
constexpr
index_t
lid_delta
=
lid_over_rid_derivative
*
(
1
<<
istage
);
// pull data from remote lane
const
auto
v_remote_mean
=
warp_shuffle_up
(
v_local_mean
,
lid_delta
);
const
auto
v_remote_var
=
warp_shuffle_up
(
v_local_var
,
lid_delta
);
const
auto
v_remote_count
=
warp_shuffle_up
(
v_local_count
,
lid_delta
);
// decide whether to update local data with remote data
v_local_mean
=
do_i_hold_reduced_data
?
v_local_mean
:
v_remote_mean
;
v_local_var
=
do_i_hold_reduced_data
?
v_local_var
:
v_remote_var
;
v_local_count
=
do_i_hold_reduced_data
?
v_local_count
:
v_remote_count
;
});
}
});
}
mean_tensor
.
get_thread_buffer
()(
i
)
=
v_local_mean
;
if
constexpr
(
GetActualVariance
)
var_tensor
.
get_thread_buffer
()(
i
)
=
v_local_var
/
v_local_count
;
else
var_tensor
.
get_thread_buffer
()(
i
)
=
v_local_var
;
count
=
v_local_count
;
});
}
};
}
// namespace ck_tile
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp
View file @
f221c2b0
...
@@ -15,8 +15,9 @@ namespace instance {
...
@@ -15,8 +15,9 @@ namespace instance {
using
namespace
ck
::
tensor_layout
::
convolution
;
using
namespace
ck
::
tensor_layout
::
convolution
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Empty_Tuple
=
ck
::
Tuple
<>
;
using
Empty_Tuple
=
ck
::
Tuple
<>
;
...
@@ -45,17 +46,42 @@ using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances = std
...
@@ -45,17 +46,42 @@ using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances = std
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
16
,
16
,
32
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
16
,
16
,
32
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
32
,
32
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
2
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
32
,
32
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
2
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
64
,
32
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
4
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
64
,
32
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
4
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
128
,
32
,
8
,
32
,
32
,
1
,
4
,
S
<
4
,
4
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
8
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
128
,
32
,
8
,
32
,
32
,
1
,
4
,
S
<
4
,
4
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
8
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
16
,
16
,
32
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
32
,
32
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
2
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
32
,
32
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
2
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
64
,
32
,
32
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
4
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
64
,
32
,
32
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
4
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
128
,
32
,
32
,
8
,
32
,
32
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
S
<
4
,
4
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
8
>
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
128
,
32
,
32
,
8
,
32
,
32
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
S
<
4
,
4
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
8
>
// clang-format on
// clang-format on
>
;
>
;
template
<
ck
::
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
ELayout
,
ConvolutionBackwardWeightSpecialization
ConvSpec
,
BlockGemmPipelineScheduler
Scheduler
,
BlockGemmPipelineVersion
PipelineVersion
>
using
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_bf16_instances
=
std
::
tuple
<
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
16
,
16
,
32
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
32
,
32
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
2
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
64
,
32
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
4
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
128
,
32
,
8
,
32
,
32
,
1
,
4
,
S
<
4
,
4
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
8
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
32
,
32
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
2
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
64
,
32
,
32
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
4
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
128
,
32
,
32
,
8
,
32
,
32
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
S
<
4
,
4
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
8
>
// clang-format on
>
;
// NGCHW requires transpose, we use vector loads and stores params for them
// NGCHW requires transpose, we use vector loads and stores params for them
template
<
ck
::
index_t
NDimSpatial
,
template
<
ck
::
index_t
NDimSpatial
,
typename
ALayout
,
typename
ALayout
,
...
@@ -96,6 +122,45 @@ using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances
...
@@ -96,6 +122,45 @@ using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances
// clang-format on
// clang-format on
>
;
>
;
template
<
ck
::
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
ELayout
,
ConvolutionBackwardWeightSpecialization
ConvSpec
,
BlockGemmPipelineScheduler
Scheduler
,
BlockGemmPipelineVersion
PipelineVersion
>
using
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instances
=
std
::
tuple
<
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
16
,
16
,
32
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
,
BF16
,
BF16
,
1
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
32
,
32
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
2
,
BF16
,
BF16
,
2
,
2
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
64
,
32
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
4
,
BF16
,
BF16
,
4
,
4
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
128
,
32
,
8
,
32
,
32
,
1
,
4
,
S
<
4
,
4
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
1
,
1
,
S
<
1
,
4
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
8
,
BF16
,
BF16
,
8
,
8
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
32
,
32
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
2
,
BF16
,
BF16
,
2
,
2
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
64
,
32
,
32
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
4
,
BF16
,
BF16
,
4
,
4
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
128
,
32
,
32
,
8
,
32
,
32
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
S
<
4
,
4
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
4
>
,
1
,
Scheduler
,
PipelineVersion
,
8
,
BF16
,
BF16
,
8
,
8
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
32
,
32
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
2
,
BF16
,
BF16
,
1
,
2
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
64
,
32
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
4
,
BF16
,
BF16
,
1
,
4
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
128
,
32
,
8
,
32
,
32
,
1
,
4
,
S
<
4
,
4
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
1
,
1
,
S
<
1
,
4
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
8
,
BF16
,
BF16
,
1
,
8
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
64
,
32
,
32
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
4
,
BF16
,
BF16
,
1
,
4
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
128
,
32
,
32
,
8
,
32
,
32
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
S
<
4
,
4
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
4
>
,
1
,
Scheduler
,
PipelineVersion
,
8
,
BF16
,
BF16
,
1
,
8
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
32
,
32
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
2
,
BF16
,
BF16
,
2
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
64
,
32
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
4
,
BF16
,
BF16
,
4
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
128
,
32
,
8
,
32
,
32
,
1
,
4
,
S
<
4
,
4
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
1
,
1
,
S
<
1
,
4
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
8
,
BF16
,
BF16
,
8
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
64
,
32
,
32
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
4
,
BF16
,
BF16
,
4
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
128
,
32
,
32
,
8
,
32
,
32
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
S
<
4
,
4
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
4
>
,
1
,
Scheduler
,
PipelineVersion
,
8
,
BF16
,
BF16
,
8
,
1
>
// clang-format on
>
;
}
// namespace instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp
View file @
f221c2b0
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
...
@@ -113,7 +113,7 @@ template <ck::index_t NDimSpatial,
...
@@ -113,7 +113,7 @@ template <ck::index_t NDimSpatial,
typename
BLayout
,
typename
BLayout
,
typename
ELayout
,
typename
ELayout
,
ConvolutionBackwardWeightSpecialization
ConvSpec
>
ConvolutionBackwardWeightSpecialization
ConvSpec
>
using
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances
=
std
::
tuple
<
using
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_
f32_bf16_
instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
...
@@ -141,6 +141,41 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances = std::tuple<
...
@@ -141,6 +141,41 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances = std::tuple<
// clang-format on
// clang-format on
>
;
>
;
template
<
ck
::
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
ELayout
,
ConvolutionBackwardWeightSpecialization
ConvSpec
>
using
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances
=
std
::
tuple
<
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
// generic instance
DeviceGroupedConvBwdWeight_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
8
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
2
,
4
,
true
,
S
<
1
,
4
,
8
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
2
,
4
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
2
>
,
// instance for small conv.K
// for bf16 conv.K and conv.C must be divisible by 2
// since half_t atomic_add require scalar_per_x_vector % 2 == 0
DeviceGroupedConvBwdWeight_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
16
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
4
,
true
,
S
<
1
,
4
,
4
,
8
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
2
,
1
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
2
>
,
DeviceGroupedConvBwdWeight_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
4
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
2
,
2
,
true
,
S
<
1
,
4
,
8
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
4
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
,
DeviceGroupedConvBwdWeight_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
32
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
4
,
true
,
S
<
1
,
4
,
16
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
2
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvBwdWeight_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
4
,
16
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
2
,
true
,
S
<
1
,
4
,
32
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
4
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvBwdWeight_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
16
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
4
,
true
,
S
<
1
,
4
,
16
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
4
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedConvBwdWeight_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
16
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
2
,
true
,
S
<
1
,
4
,
16
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
2
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedConvBwdWeight_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
16
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
4
,
true
,
S
<
1
,
4
,
8
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
2
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedConvBwdWeight_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
8
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
2
,
true
,
S
<
1
,
4
,
16
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
4
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedConvBwdWeight_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
8
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
4
,
true
,
S
<
1
,
4
,
8
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
4
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
,
DeviceGroupedConvBwdWeight_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
16
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
2
,
true
,
S
<
1
,
4
,
8
,
8
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
1
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedConvBwdWeight_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
8
,
8
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
1
,
true
,
S
<
1
,
4
,
16
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
2
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedConvBwdWeight_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
16
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
4
,
true
,
S
<
1
,
4
,
4
,
8
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
1
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedConvBwdWeight_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
4
,
8
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
1
,
true
,
S
<
1
,
4
,
16
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
4
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedConvBwdWeight_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
8
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
4
,
true
,
S
<
1
,
4
,
4
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
2
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
,
DeviceGroupedConvBwdWeight_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
4
,
4
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
2
,
true
,
S
<
1
,
4
,
8
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
2
,
8
,
4
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
// clang-format on
>
;
template
<
ck
::
index_t
NDimSpatial
,
template
<
ck
::
index_t
NDimSpatial
,
typename
ALayout
,
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
View file @
f221c2b0
...
@@ -367,6 +367,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -367,6 +367,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances
(
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances
(
op_ptrs
);
op_ptrs
);
}
}
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
WeiDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
ComputeTypeA
,
ck
::
bhalf_t
>
&&
is_same_v
<
ComputeTypeB
,
ck
::
bhalf_t
>
)
{
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instances
(
op_ptrs
);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instances
(
op_ptrs
);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_instances
(
op_ptrs
);
}
#endif
#endif
}
}
if
constexpr
(
is_same_v
<
InLayout
,
NGCHW
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
if
constexpr
(
is_same_v
<
InLayout
,
NGCHW
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
...
@@ -382,6 +395,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -382,6 +395,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances
(
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances
(
op_ptrs
);
op_ptrs
);
}
}
#endif
#ifdef CK_ENABLE_BF16
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
WeiDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
ComputeTypeA
,
ck
::
bhalf_t
>
&&
is_same_v
<
ComputeTypeB
,
ck
::
bhalf_t
>
)
{
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instances
(
op_ptrs
);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev5_instances
(
op_ptrs
);
}
#endif
#endif
}
}
}
}
...
@@ -453,6 +479,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -453,6 +479,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances
(
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances
(
op_ptrs
);
op_ptrs
);
}
}
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
WeiDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
ComputeTypeA
,
ck
::
bhalf_t
>
&&
is_same_v
<
ComputeTypeB
,
ck
::
bhalf_t
>
)
{
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instances
(
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instances
(
op_ptrs
);
}
#endif
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
...
@@ -477,6 +516,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -477,6 +516,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instances
(
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instances
(
op_ptrs
);
op_ptrs
);
}
}
#endif
#ifdef CK_ENABLE_BF16
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
WeiDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
ComputeTypeA
,
ck
::
bhalf_t
>
&&
is_same_v
<
ComputeTypeB
,
ck
::
bhalf_t
>
)
{
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instances
(
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instances
(
op_ptrs
);
}
#endif
#endif
}
}
}
}
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc
View file @
f221c2b0
...
@@ -89,6 +89,18 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
...
@@ -89,6 +89,18 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_BF16
#ifdef CK_ENABLE_BF16
void
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NHWGC
,
GKYXC
,
NHWGK
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances
(
void
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NHWGC
,
NHWGC
,
...
@@ -100,6 +112,53 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_in
...
@@ -100,6 +112,53 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_in
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NHWGC
,
GKYXC
,
NHWGK
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NHWGC
,
GKYXC
,
NHWGK
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NGCHW
,
GKYXC
,
NGKHW
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev5_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NGCHW
,
GKYXC
,
NGKHW
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_FP16
#ifdef CK_ENABLE_FP16
void
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances
(
void
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances
(
...
@@ -215,6 +274,18 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances
...
@@ -215,6 +274,18 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_BF16
#ifdef CK_ENABLE_BF16
void
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances
(
void
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
NDHWGC
,
...
@@ -226,6 +297,53 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16
...
@@ -226,6 +297,53 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NGCDHW
,
GKZYXC
,
NGKDHW
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NGCDHW
,
GKZYXC
,
NGKDHW
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_FP16
#ifdef CK_ENABLE_FP16
void
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
(
void
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
(
...
...
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp
View file @
f221c2b0
...
@@ -44,8 +44,11 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = std::tu
...
@@ -44,8 +44,11 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = std::tu
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
4
,
4
,
32
,
32
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
4
,
4
,
32
,
32
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
4
,
4
,
32
,
32
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
4
,
4
,
32
,
32
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
2
,
2
,
32
,
32
,
4
,
4
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
4
,
4
,
32
,
32
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
4
,
4
,
32
,
32
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
2
,
2
,
32
,
32
,
4
,
4
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
// Can we support this kind of odd case? 224(256) = 28*8 + (4*8)
// Can we support this kind of odd case? 224(256) = 28*8 + (4*8)
//DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
//DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
...
@@ -64,10 +67,13 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_instances = std::tup
...
@@ -64,10 +67,13 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_instances = std::tup
// Latency friendly
// Latency friendly
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
64
,
4
,
4
,
16
,
16
,
1
,
1
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
64
,
4
,
4
,
16
,
16
,
1
,
1
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
64
,
2
,
2
,
16
,
16
,
1
,
1
,
S
<
32
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
S
<
32
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
64
,
4
,
4
,
16
,
16
,
1
,
1
,
S
<
16
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
S
<
16
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
64
,
4
,
4
,
16
,
16
,
1
,
1
,
S
<
16
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
S
<
16
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
32
,
64
,
4
,
4
,
16
,
16
,
1
,
1
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
32
,
64
,
4
,
4
,
16
,
16
,
1
,
1
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
32
,
64
,
2
,
2
,
16
,
16
,
1
,
1
,
S
<
32
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
S
<
32
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
// Memory friendly
// Memory friendly
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
16
,
64
,
8
,
2
,
16
,
16
,
4
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
0
,
S
<
32
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
16
,
64
,
8
,
2
,
16
,
16
,
4
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
0
,
S
<
32
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
16
,
64
,
2
,
2
,
16
,
16
,
4
,
1
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
S
<
32
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
128
,
16
,
64
,
8
,
4
,
16
,
16
,
4
,
1
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
0
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
128
,
16
,
64
,
8
,
4
,
16
,
16
,
4
,
1
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
0
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
64
,
16
,
64
,
4
,
4
,
16
,
16
,
2
,
1
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
64
,
16
,
64
,
4
,
4
,
16
,
16
,
2
,
1
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
64
,
4
,
4
,
16
,
16
,
1
,
1
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
64
,
4
,
4
,
16
,
16
,
1
,
1
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
...
@@ -75,7 +81,8 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_instances = std::tup
...
@@ -75,7 +81,8 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_instances = std::tup
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
32
,
64
,
4
,
4
,
16
,
16
,
1
,
1
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
32
,
64
,
4
,
4
,
16
,
16
,
1
,
1
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
64
,
64
,
4
,
4
,
16
,
16
,
1
,
2
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
64
,
64
,
4
,
4
,
16
,
16
,
1
,
2
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
128
,
64
,
4
,
4
,
16
,
16
,
1
,
4
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
128
,
64
,
4
,
4
,
16
,
16
,
1
,
4
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
16
,
256
,
64
,
2
,
4
,
16
,
16
,
1
,
4
,
S
<
32
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
16
,
256
,
64
,
2
,
4
,
16
,
16
,
1
,
4
,
S
<
32
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
16
,
256
,
64
,
2
,
2
,
16
,
16
,
1
,
4
,
S
<
32
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
// clang-format on
// clang-format on
>
;
>
;
}
// namespace instance
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp
View file @
f221c2b0
...
@@ -44,13 +44,21 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::tu
...
@@ -44,13 +44,21 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::tu
// Compute friendly
// Compute friendly
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
4
,
8
,
32
,
32
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
4
,
8
,
32
,
32
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
4
,
4
,
32
,
32
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
2
,
2
,
32
,
32
,
4
,
4
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
4
,
8
,
32
,
32
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
4
,
8
,
32
,
32
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
4
,
4
,
32
,
32
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
2
,
2
,
32
,
32
,
4
,
4
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
4
,
8
,
32
,
32
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
4
,
8
,
32
,
32
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
4
,
4
,
32
,
32
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
2
,
2
,
32
,
32
,
4
,
4
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
224
,
64
,
8
,
8
,
16
,
16
,
8
,
7
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
2
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
224
,
64
,
8
,
8
,
16
,
16
,
8
,
7
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
2
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
32
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
// clang-format on
// clang-format on
>
;
>
;
...
@@ -64,18 +72,23 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_instances = std::tup
...
@@ -64,18 +72,23 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_instances = std::tup
// Latency friendly
// Latency friendly
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
64
,
4
,
8
,
16
,
16
,
1
,
1
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
64
,
4
,
8
,
16
,
16
,
1
,
1
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
64
,
4
,
4
,
16
,
16
,
1
,
1
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
S
<
16
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
64
,
2
,
2
,
16
,
16
,
1
,
1
,
S
<
32
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
S
<
32
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
64
,
4
,
8
,
16
,
16
,
1
,
1
,
S
<
16
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
64
,
4
,
8
,
16
,
16
,
1
,
1
,
S
<
16
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
32
,
64
,
4
,
8
,
16
,
16
,
1
,
1
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
32
,
64
,
4
,
8
,
16
,
16
,
1
,
1
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
// Memory friendly
// Memory friendly
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
16
,
64
,
8
,
8
,
16
,
16
,
4
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
16
,
64
,
8
,
8
,
16
,
16
,
4
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
128
,
16
,
64
,
8
,
8
,
16
,
16
,
4
,
1
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
128
,
16
,
64
,
8
,
8
,
16
,
16
,
4
,
1
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
64
,
16
,
64
,
4
,
8
,
16
,
16
,
2
,
1
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
64
,
16
,
64
,
4
,
8
,
16
,
16
,
2
,
1
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
64
,
16
,
64
,
4
,
4
,
16
,
16
,
2
,
1
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
S
<
16
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
64
,
4
,
8
,
16
,
16
,
1
,
1
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
64
,
4
,
8
,
16
,
16
,
1
,
1
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
64
,
4
,
8
,
16
,
16
,
1
,
1
,
S
<
16
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
64
,
4
,
8
,
16
,
16
,
1
,
1
,
S
<
16
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
32
,
64
,
4
,
8
,
16
,
16
,
1
,
1
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
32
,
64
,
4
,
8
,
16
,
16
,
1
,
1
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
64
,
64
,
4
,
8
,
16
,
16
,
1
,
2
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
64
,
64
,
4
,
8
,
16
,
16
,
1
,
2
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
64
,
64
,
4
,
4
,
16
,
16
,
1
,
2
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
S
<
16
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
128
,
64
,
4
,
8
,
16
,
16
,
1
,
4
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
128
,
64
,
4
,
8
,
16
,
16
,
1
,
4
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
16
,
256
,
64
,
2
,
8
,
16
,
16
,
1
,
4
,
S
<
32
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
16
,
256
,
64
,
2
,
8
,
16
,
16
,
1
,
4
,
S
<
32
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGemm_Xdl_CShuffleV3
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
16
,
256
,
64
,
2
,
2
,
16
,
16
,
1
,
4
,
S
<
32
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
0
,
S
<
32
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
2
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
// clang-format on
// clang-format on
>
;
>
;
}
// namespace instance
}
// namespace instance
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment