gaoqiong / composable_kernel_ROCM / Commits

Commit 49c39b51, authored Nov 05, 2024 by carlushuang

moe pipeline

parent 03c6448b
Showing 20 changed files with 2857 additions and 29 deletions:
- example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.cpp (+2, -2)
- example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp (+6, -6)
- example/ck_tile/06_permute/permute.cpp (+1, -1)
- example/ck_tile/15_fused_moe/fused_moegemm.hpp (+66, -0)
- example/ck_tile/15_fused_moe/main.cpp (+415, -0)
- include/ck_tile/core/arch/amd_buffer_addressing.hpp (+112, -0)
- include/ck_tile/core/tensor/buffer_view.hpp (+98, -6)
- include/ck_tile/core/tensor/load_tile.hpp (+27, -9)
- include/ck_tile/core/tensor/tensor_view.hpp (+42, -0)
- include/ck_tile/core/tensor/tile_window.hpp (+67, -0)
- include/ck_tile/core/tensor/tile_window_linear.hpp (+52, -0)
- include/ck_tile/core/tensor/update_tile.hpp (+52, -3)
- include/ck_tile/host/reference/reference_permute.hpp (+22, -2)
- include/ck_tile/ops/fused_moe.hpp (+14, -0)
- include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp (+362, -0)
- include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp (+124, -0)
- include/ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp (+33, -0)
- include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp (+572, -0)
- include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp (+744, -0)
- include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp (+46, -0)
example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.cpp

@@ -40,7 +40,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t,
     else if(t.permute.compare("0,1,3,4,2,5") == 0)
     {
         constexpr matrix_core_permute_style pstyle =
-            matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
+            matrix_core_permute_style::b_nr_kr_kw_nw_kv;
         using Kernel =
             matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
@@ -83,7 +83,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t,
     else if(t.permute.compare("0,1,3,4,2,5") == 0)
     {
         constexpr matrix_core_permute_style pstyle =
-            matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
+            matrix_core_permute_style::b_nr_kr_kw_nw_kv;
         using Kernel =
             matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp

@@ -42,8 +42,8 @@ enum class matrix_core_permute_style
 {
     permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
     permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
-    permute_b_nr_kr_kw_nw_kv    = 2, // 0,1,3,4,2,5
-    permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
+    b_nr_kr_kw_nw_kv    = 2, // 0,1,3,4,2,5
+    b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
 };
 // assume this is B matrix, originally we have batch*n*k
@@ -203,7 +203,7 @@ struct matrix_core_swizzle_kernel
         else
         {
             // clang-format off
-            // permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten
+            // b_nr_kr_kw_nw_kv or b_nr_kr_waveflatten
             constexpr index_t Kv = Alignment;
             constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
             constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
@@ -332,7 +332,7 @@ struct matrix_core_swizzle_kernel
                 make_tuple(sequence<0>{}, sequence<1>{}));
             return tmp_1;
 #else
-            // permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
+            // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
             constexpr index_t kv = Alignment;
             constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
             constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
@@ -376,13 +376,13 @@ struct matrix_core_swizzle_kernel
         else
         {
 #if MERGE_2D_013425
-            // permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
+            // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv
             return make_tile_window(dst_view,
                                     make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
                                     {i_n * NPerBlock, i_k * KPerBlock},
                                     get_dst_dist());
 #else
-            // permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
+            // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv
             constexpr index_t kv = Alignment;
             constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
             constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
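The index strings in the enum comments follow the convention that output dimension i takes input dimension perm[i] (the same convention reference_permute below uses). A minimal standalone sketch, not part of the commit, showing that "0,1,3,4,2,5" applied to a B tensor split as [b, nr, nw, kr, kw, kv] (assumed meaning: batch, n/k block-repeat dims, per-wave lane dims, and the kv vector dim) yields exactly the [b, nr, kr, kw, nw, kv] order named by the renamed enumerator:

    #include <array>
    #include <cstdio>

    int main()
    {
        // input dimension names for the 6-D split of B
        std::array<const char*, 6> in_dims{"b", "nr", "nw", "kr", "kw", "kv"};
        std::array<int, 6> perm{0, 1, 3, 4, 2, 5}; // the "0,1,3,4,2,5" permutation
        for(int i = 0; i < 6; i++)
            std::printf("%s ", in_dims[perm[i]]); // prints: b nr kr kw nw kv
        std::printf("\n");
        return 0;
    }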
example/ck_tile/06_permute/permute.cpp

@@ -264,7 +264,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     {
         if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5"))
         {
-            // permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
+            // b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
             matrix_core_swizzle_traits t;
             t.data_type = data_type;
             t.permute   = arg_parser.get_str("perm");
example/ck_tile/15_fused_moe/fused_moegemm.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/layernorm2d.hpp"
#include <string>

// this is only a convenient structure for creating an example
// this is not part of the host API
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig;

template <typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, ST, SW, SQ, KW>
{
    using ADataType            = ck_tile::bf16_t;
    using GDataType            = ck_tile::bf16_t;
    using DDataType            = ck_tile::bf16_t;
    using AccDataType          = float;
    using ODataType            = ck_tile::bf16_t;
    using AScaleDataType       = ck_tile::remove_cvref_t<ST>;
    using W0ScaleDataType      = ck_tile::remove_cvref_t<SW>;
    using W1ScaleDataType      = ck_tile::remove_cvref_t<SW>;
    using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
    using TopkWeightDataType   = ck_tile::remove_cvref_t<KW>;
    using IndexDataType        = ck_tile::index_t;
};

template <typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, ST, SW, SQ, KW>
{
    using ADataType            = ck_tile::int8_t;
    using GDataType            = ck_tile::int8_t;
    using DDataType            = ck_tile::int8_t;
    using AccDataType          = int32_t;
    using ODataType            = ck_tile::bf16_t;
    using AScaleDataType       = ck_tile::remove_cvref_t<ST>;
    using W0ScaleDataType      = ck_tile::remove_cvref_t<SW>;
    using W1ScaleDataType      = ck_tile::remove_cvref_t<SW>;
    using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
    using TopkWeightDataType   = ck_tile::remove_cvref_t<KW>;
    using IndexDataType        = ck_tile::index_t;
};

// runtime args
struct fused_moegemm_args : public ck_tile::Layernorm2dFwdHostArgs
{
};

// This is the public API, will be generated by script
struct fused_moegemm_traits
{
    std::string prec_i;  // input precision
    std::string prec_w;  // weight precision
    std::string prec_o;  // output precision
    std::string prec_st; // token scale data type
    std::string prec_sw; // weight scale data type
    std::string prec_sq; // smooth quant scale
    std::string prec_kw; // topk-weight data type
    int fused_quant;     // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};

float fused_moegemm(fused_moegemm_traits, fused_moegemm_args, const ck_tile::stream_config&);
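A minimal sketch (not part of the commit) of how these type-config specializations are consumed: choose the input/weight/output triple and the accumulator type falls out as a member alias, fp32 for the bf16 path and int32 for the int8 path:

    #include <type_traits>

    using Bf16Cfg = FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t,
                                           float, float, float, float>;
    static_assert(std::is_same_v<Bf16Cfg::AccDataType, float>, "bf16 path accumulates in fp32");

    using Int8Cfg = FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t,
                                           float, float, float, float>;
    static_assert(std::is_same_v<Int8Cfg::AccDataType, int32_t>, "int8 path accumulates in int32");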
example/ck_tile/15_fused_moe/main.cpp (new file, mode 100644)

#include "ck_tile/host.hpp"
#include "layernorm2d_fwd.hpp"
#include <algorithm>
#include <cstring>

// different threshold for different dtype
template <typename DataType>
auto get_elimit()
{
    double rtol = 1e-2;
    double atol = 1e-2;
    return ck_tile::make_tuple(rtol, atol);
}

template <>
auto get_elimit<ck_tile::bf16_t>()
{
    double rtol = 1e-2;
    double atol = 1e-2;
    return ck_tile::make_tuple(rtol, atol);
}

// mfma_type, 0:32x32, 1:16x16
template <typename H>
auto shuffle_moe_weight(const H& t, std::string mfma_dtype, int mfma_type = 0)
{
    static_assert(t.get_lengths().size() == 3);
    int b_ = t.get_lengths()[0];
    int n_ = t.get_lengths()[1];
    int k_ = t.get_lengths()[2];
    if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0)
    {
        std::vector<ck_tile::index_t> new_lens{b_, n_ / 32, 32, k_ / 16, 2, 8};
    }
}
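shuffle_moe_weight above is still a stub at this commit: it computes new_lens and returns without reshaping. The split it encodes is n -> (n/32, 32) and k -> (k/16, 2, 8), with 2*8 = 16, so every (n, k) element has a unique 6-D home. A self-contained sketch (not part of the commit) of that index arithmetic:

    #include <cstdio>

    int main()
    {
        int n = 70, k = 37;                             // an arbitrary (n, k) element
        int n0 = n / 32, n1 = n % 32;                   // n -> (n/32, 32)
        int k0 = k / 16, k1 = (k % 16) / 8, k2 = k % 8; // k -> (k/16, 2, 8)
        std::printf("(n,k)=(%d,%d) -> (n0,n1)=(%d,%d), (k0,k1,k2)=(%d,%d,%d)\n",
                    n, k, n0, n1, k0, k1, k2);          // (2,6) and (2,0,5)
        return 0;
    }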
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("t", "128", "num input tokens")
        .insert("e", "32", "num of experts")
        .insert("k", "5", "topk")
        .insert("h", "8192", "hidden_size of this model")
        .insert("i", "8192", "intermediate_size between 2 gemms of FFN")
        .insert("stride", "-1", "stride per row, if -1 then equal to hidden_size")
        .insert("bm", "32", "blocking factor for sorted tokens")
        .insert("tp", "8", "tensor parallel size")
        .insert("v", "1", "cpu validation or not")
        .insert("kname", "1", "print kernel name or not")
        .insert("prec_i", "bf16", "input precision")
        .insert("prec_w", "bf16", "weight precision")
        .insert("prec_o", "bf16", "output precision")
        .insert("prec_st", "auto", "token scale data type. auto will set to fp32")
        .insert("prec_sw", "auto", "weight scale data type. auto will set to fp32")
        .insert("prec_sq", "auto", "(dynamic) smooth quant data type. auto will set to fp32")
        .insert("prec_kw", "auto", "topk-weight data type. auto will set to fp32")
        .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
        .insert("gonly", "0", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
        .insert("balance", "1", "if set to 1, will try balance the expert in topk-ids(convenient for testing)")
        .insert("warmup", "5", "cold iter")
        .insert("repeat", "20", "hot iter");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
// I:input-type, W:weight-type, O:output-type, ST:token-scale-type, SW:weight-scale-type, SQ:smooth-quant-type, KW:topk-weight-type
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
bool run(const ck_tile::ArgParser& arg_parser)
{
    ck_tile::index_t tokens            = arg_parser.get_int("t");
    ck_tile::index_t experts           = arg_parser.get_int("e");
    ck_tile::index_t topk              = arg_parser.get_int("k");
    ck_tile::index_t hidden_size       = arg_parser.get_int("h");
    ck_tile::index_t intermediate_size = arg_parser.get_int("i");
    ck_tile::index_t stride            = arg_parser.get_int("stride");
    ck_tile::index_t block_m           = arg_parser.get_int("bm");
    if(stride < 0)
        stride = hidden_size;
    std::string prec_i  = arg_parser.get_str("prec_i");
    std::string prec_w  = arg_parser.get_str("prec_w");
    std::string prec_o  = arg_parser.get_str("prec_o");
    std::string prec_st = arg_parser.get_str("prec_st");
    std::string prec_sw = arg_parser.get_str("prec_sw");
    std::string prec_sq = arg_parser.get_str("prec_sq");
    std::string prec_kw = arg_parser.get_str("prec_kw");
    prec_st             = (prec_st == "auto") ? "fp32" : prec_st;
    prec_sw             = (prec_sw == "auto") ? "fp32" : prec_sw;
    prec_sq             = (prec_sq == "auto") ? "fp32" : prec_sq;
    prec_kw             = (prec_kw == "auto") ? "fp32" : prec_kw;

    int kname         = arg_parser.get_int("kname");
    int do_validation = arg_parser.get_int("v");
    int warmup        = arg_parser.get_int("warmup");
    int repeat        = arg_parser.get_int("repeat");
    int fused_quant   = arg_parser.get_int("fquant");
    int gonly         = arg_parser.get_int("gonly");
    int balance       = arg_parser.get_int("balance");
    int tp            = arg_parser.get_int("tp");

    ck_tile::index_t shared_intermediate_size = intermediate_size * (gonly ? 1 : 2) / tp;

    using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;

    using ADataType            = typename TypeConfig::ADataType;
    using GDataType            = typename TypeConfig::GDataType;
    using DDataType            = typename TypeConfig::DDataType;
    using AccDataType          = typename TypeConfig::AccDataType;
    using ODataType            = typename TypeConfig::ODataType;
    using AScaleDataType       = typename TypeConfig::AScaleDataType;
    using W0ScaleDataType      = typename TypeConfig::W0ScaleDataType;
    using W1ScaleDataType      = typename TypeConfig::W1ScaleDataType;
    using YSmoothScaleDataType = typename TypeConfig::YSmoothScaleDataType;
    using TopkWeightDataType   = typename TypeConfig::TopkWeightDataType;
    using IndexDataType        = typename TypeConfig::IndexDataType;

    // host verify
    ck_tile::HostTensor<ADataType> a_host({tokens, hidden_size}, {stride, 1});
    ck_tile::HostTensor<ADataType> g_host({e, shared_intermediate_size, hidden_size});
    ck_tile::HostTensor<ADataType> d_host({e, intermediate_size, hidden_size});
    ck_tile::HostTensor<XResidualDataType> x_residual_host({m, n}, {stride, 1});
    ck_tile::HostTensor<YResidualDataType> y_residual_host({m, n}, {stride, 1});
    ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {stride, 1});
    ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {stride, 1});
    ck_tile::HostTensor<MeanDataType> mean_host_ref({m});
    ck_tile::HostTensor<InvStdDataType> invStd_host_ref({m});
    ck_tile::HostTensor<YScaleDataType> y_scale_host_ref({m});
    ck_tile::HostTensor<YScaleDataType> y_scale_host_dev({m});
    ck_tile::HostTensor<XScaleDataType> x_scale_host({n});
    ck_tile::HostTensor<XScaleDataType> x_scale_host_dev({n});

    ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
    ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host);
    ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host);
    ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
    ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);

    ck_tile::DeviceMem x_buf(a_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
    ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes());
    ck_tile::DeviceMem x_scale_buf(x_scale_host_dev.get_element_space_size_in_bytes());
    ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes());

    x_buf.ToDevice(a_host.data());
    gamma_buf.ToDevice(gamma_host.data());
    beta_buf.ToDevice(beta_host.data());
    x_residual_buf.ToDevice(x_residual_host.data());
    x_scale_buf.ToDevice(x_scale_host.data());

    auto prec_str = [&]() {
        auto base_str = prec_i;
        if(prec_i != prec_o)
        {
            base_str += "|" + prec_o;
        }
        if(fused_quant == 1)
        {
            base_str += std::string("(") + prec_sy + ")";
        }
        return base_str;
    }();

    std::cout << "[" << prec_str << "]"
              << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;

    layernorm2d_fwd_traits traits{
        prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant};

    layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
                              fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr,
                              fused_quant == 1 ? x_scale_buf.GetDeviceBuffer() : nullptr,
                              gamma_buf.GetDeviceBuffer(),
                              beta_buf.GetDeviceBuffer(),
                              y_buf.GetDeviceBuffer(),
                              fused_add == 1 ? y_residual_buf.GetDeviceBuffer() : nullptr,
                              fused_quant != 0 ? y_scale_buf.GetDeviceBuffer() : nullptr,
                              nullptr, // p_mean, unsupported yet
                              nullptr, // p_invStd, unsupported yet
                              epsilon,
                              m,
                              n,
                              stride};

    float ave_time = layernorm2d_fwd(
        traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});

    if(ave_time < 0)
    {
        std::cout << " not supported!" << std::endl << std::flush;
        return false;
    }

    std::size_t num_byte = sizeof(ADataType) * m * n + sizeof(GammaDataType) * n +
                           sizeof(BetaDataType) * n + sizeof(YDataType) * m * n;

    float gb_per_sec = num_byte / 1.E6 / ave_time;
    std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;

    bool pass = true;

    if(do_validation)
    {
        // reference
        if(fused_add != 0)
        {
            // fused pre_add/pre_add_store
            // TODO we accumulate directly to a_host for simplicity here...
            std::transform(a_host.mData.cbegin(),
                           a_host.mData.cend(),
                           x_residual_host.mData.cbegin(),
                           a_host.mData.begin(),
                           [](auto x_, auto r_) {
                               auto o_ = ck_tile::type_convert<ComputeDataType>(x_) +
                                         ck_tile::type_convert<ComputeDataType>(r_);
                               return ck_tile::type_convert<ADataType>(o_);
                           });
        }
        ck_tile::reference_layernorm2d_fwd<ADataType,
                                           GammaDataType,
                                           BetaDataType,
                                           ComputeDataType,
                                           YDataType,
                                           MeanDataType,
                                           InvStdDataType>(
            a_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);

        if(fused_quant != 0)
        {
            auto dquant_functor = [&](int m_, auto& o_, auto& acc_) {
                int N_ = acc_.mDesc.get_lengths()[1];
                if(fused_quant == 1)
                {
                    for(int n_ = 0; n_ < N_; n_++)
                    {
                        // input smooth outlier
                        acc_(m_, n_) =
                            acc_(m_, n_) * ck_tile::type_convert<ComputeDataType>(x_scale_host(n_));
                    }
                }
                ComputeDataType absmax = static_cast<ComputeDataType>(0);
                for(int n_ = 0; n_ < N_; n_++)
                {
                    const auto a = ck_tile::abs(acc_(m_, n_));
                    absmax       = a > absmax ? a : absmax;
                }
                // printf("cpu:absmax:%f\n", absmax);
                ComputeDataType y_scale = absmax / static_cast<ComputeDataType>(127.0);
                y_scale_host_ref(m_)    = ck_tile::type_convert<YScaleDataType>(y_scale);
                for(int n_ = 0; n_ < N_; n_++)
                {
                    o_(m_, n_) = ck_tile::type_convert<YDataType>(acc_(m_, n_) / y_scale);
                }
            };

            ck_tile::reference_layernorm2d_fwd<ADataType,
                                               GammaDataType,
                                               BetaDataType,
                                               ComputeDataType,
                                               YDataType,
                                               MeanDataType,
                                               InvStdDataType>(a_host,
                                                               gamma_host,
                                                               beta_host,
                                                               y_host_ref,
                                                               mean_host_ref,
                                                               invStd_host_ref,
                                                               epsilon,
                                                               dquant_functor);
        }
        else
        {
            ck_tile::reference_layernorm2d_fwd<ADataType,
                                               GammaDataType,
                                               BetaDataType,
                                               ComputeDataType,
                                               YDataType,
                                               MeanDataType,
                                               InvStdDataType>(
                a_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);
        }

        y_buf.FromDevice(y_host_dev.data());

        ck_tile::HostTensor<YResidualDataType> y_residual_host_dev({m, n}, {stride, 1});
        if(fused_add == 1)
        {
            y_residual_buf.FromDevice(y_residual_host_dev.data());
        }

        auto [rtol, atol] = get_elimit<InDataType>();

        if(stride == n)
        {
            pass = ck_tile::check_err(
                y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
            if(fused_add == 1)
            {
                pass &= ck_tile::check_err(y_residual_host_dev,
                                           a_host,
                                           std::string("ADD Error: Incorrect results!"),
                                           rtol,
                                           atol);
            }
        }
        else
        {
            for(int i_r = 0; i_r < m; i_r++)
            {
                std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * stride,
                                                      y_host_dev.begin() + i_r * stride + n);
                std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * stride,
                                                      y_host_ref.begin() + i_r * stride + n);
                pass &= ck_tile::check_err(y_host_dev_row,
                                           y_host_ref_row,
                                           std::string("OUT[") + std::to_string(i_r) +
                                               std::string("] Error: Incorrect results!"),
                                           rtol,
                                           atol);
                if(fused_add == 1)
                {
                    std::vector<YResidualDataType> y_residual_host_dev_row(
                        y_residual_host_dev.begin() + i_r * stride,
                        y_residual_host_dev.begin() + i_r * stride + n);
                    std::vector<YResidualDataType> y_residual_host_ref_row(
                        a_host.begin() + i_r * stride, a_host.begin() + i_r * stride + n);
                    pass &= ck_tile::check_err(y_residual_host_dev_row,
                                               y_residual_host_ref_row,
                                               std::string("ADD[") + std::to_string(i_r) +
                                                   std::string("] Error: Incorrect results!"),
                                               rtol,
                                               atol);
                }
            }
        }

        if(fused_quant == 1)
        {
            y_scale_buf.FromDevice(y_scale_host_dev.data());
            pass &= ck_tile::check_err(y_scale_host_dev,
                                       y_scale_host_ref,
                                       std::string("SCALE Error: Incorrect results!"),
                                       rtol,
                                       atol);
        }

        std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
    }

    return pass;
}
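run() above still drives the layernorm2d_fwd path this file was copied from, but its dquant_functor already shows the dynamic-quantization recipe: scale each row by its absolute maximum, y_scale = absmax / 127, then divide through. A self-contained check of that arithmetic (not part of the commit):

    #include <cmath>
    #include <cstdio>

    int main()
    {
        float acc[4] = {0.5f, -2.54f, 1.27f, 0.0f}; // one row of accumulator values
        float absmax = 0.f;
        for(float v : acc)
            absmax = std::fabs(v) > absmax ? std::fabs(v) : absmax;
        float y_scale = absmax / 127.0f; // 2.54 / 127 = 0.02
        for(float v : acc)
            std::printf("%ld ", std::lround(v / y_scale)); // 25 -127 64 0: fits int8
        std::printf("(scale=%f)\n", y_scale);
        return 0;
    }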
int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    std::string prec_i  = arg_parser.get_str("prec_i");
    std::string prec_o  = arg_parser.get_str("prec_o");
    std::string prec_sx = arg_parser.get_str("prec_sx");
    std::string prec_sy = arg_parser.get_str("prec_sy");
    if(prec_o == "auto")
    {
        prec_o = prec_i;
    }
    if(prec_sx == "auto")
    {
        prec_sx = "fp32";
    }
    if(prec_sy == "auto")
    {
        prec_sy = "fp32";
    }
    int save_mv = arg_parser.get_int("save_mv");

    // no dynamic quant case
    if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32")
    {
        return run<ck_tile::half_t, ck_tile::half_t, float, float, true>(arg_parser) ? 0 : -2;
    }
    else if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32")
    {
        return run<ck_tile::half_t, ck_tile::half_t, float, float, false>(arg_parser) ? 0 : -2;
    }
    else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32")
    {
        return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
    }
    else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32")
    {
        return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
    }
    // dynamic quant case, only in inference
    else if(prec_i == "fp16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32")
    {
        return run<ck_tile::half_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
    }
    else if(prec_i == "bf16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32")
    {
        return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
    }

    return -3;
}
include/ck_tile/core/arch/amd_buffer_addressing.hpp

@@ -53,6 +53,11 @@ template<> struct buffer_load_trait<4 , thread_buffer<bf16_t, 2>> { using payloa
 // clang-format on
 } // namespace impl

+// TODO: this is hot-tmp fix to unblock user case. Need refactor into template arg
+#ifndef CK_TILE_BUFFER_LOAD_AGPR
+#define CK_TILE_BUFFER_LOAD_AGPR 0
+#endif
+
 // TODO: glc/slc/...
 template <index_t bytes, bool pre_nop = false>
 struct buffer_load;
@@ -74,6 +79,19 @@ struct buffer_load<16, pre_nop>
     {
         static_assert(sizeof(T) == 16);
         using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
+#if CK_TILE_BUFFER_LOAD_AGPR
+        if constexpr(pre_nop)
+            asm volatile("s_nop 4\n"
+                         "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
+                         : "=a"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset)
+                         : "memory");
+        else
+            asm volatile("buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
+                         : "=a"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset)
+                         : "memory");
+#else
         if constexpr(pre_nop)
             asm volatile("s_nop 4\n"
                          "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
@@ -85,6 +103,7 @@ struct buffer_load<16, pre_nop>
                          : "+v"(reinterpret_cast<mbuf_t&>(value))
                          : "v"(v_offset), "s"(res), "n"(i_offset)
                          : "memory");
+#endif
     }
 };
@@ -621,6 +640,60 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
     asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
 }

+template <typename scalar_type, index_t N, bool pre_nop = false>
+struct buffer_atomic_add_if;
+
+template <bool pre_nop>
+struct buffer_atomic_add_if<bf16_t, 2, pre_nop>
+{
+    template <typename T>
+    CK_TILE_DEVICE void operator()(const T& value,
+                                   int32x4_t res /*buffer resource*/,
+                                   index_t v_offset,
+                                   index_t /*s_offset*/,
+                                   index_t i_offset /*max 0xFFF*/,
+                                   index_t flag = 1)
+    {
+        static_assert(sizeof(T) == 4);
+        auto save_exec = __builtin_amdgcn_read_exec();
+        using mbuf_t   = float;
+        asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
+                     "global_atomic_pk_add_bf16 %0, %1, %2 offset:%3\n"
+                     "s_mov_b64 exec %5"
+                     :
+                     : "v"(v_offset),
+                       "v"(bit_cast<mbuf_t>(value)),
+                       "s"(res.xy),
+                       "n"(i_offset),
+                       "v"(flag),
+                       "s"(save_exec)
+                     : "memory");
+    }
+};
+
+template <typename scalar_type, index_t N, bool pre_nop = false>
+struct buffer_atomic_add;
+
+template <bool pre_nop>
+struct buffer_atomic_add<bf16_t, 2, pre_nop>
+{
+    template <typename T>
+    CK_TILE_DEVICE void operator()(const T& value,
+                                   int32x4_t res /*buffer resource*/,
+                                   index_t v_offset,
+                                   index_t /*s_offset*/,
+                                   index_t i_offset /*max 0xFFF*/,
+                                   index_t /*flag = 1*/)
+    {
+        static_assert(sizeof(T) == 4);
+        using mbuf_t = float;
+        asm volatile("global_atomic_pk_add_bf16 %0, %1, %2 offset:%3"
+                     :
+                     : "v"(v_offset), "v"(bit_cast<mbuf_t>(value)), "s"(res.xy), "n"(i_offset)
+                     : "memory");
+    }
+};
+
 namespace impl {
 // below type indicate the data type used for buffer load inline asm
 // clang-format off
@@ -2378,6 +2451,45 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_
 #endif
 }

+template <typename T,
+          index_t N,
+          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
+          bool oob_conditional_check = true,
+          bool pre_nop = false>
+CK_TILE_DEVICE void amd_buffer_atomic_add_raw(const thread_buffer<T, N>& src_thread_data,
+                                              T* p_dst_wave,
+                                              const index_t dst_thread_element_offset,
+                                              const index_t dst_linear_element_offset,
+                                              const bool dst_thread_element_valid,
+                                              const index_t dst_element_space_size,
+                                              bool_constant<pre_nop> = {})
+{
+    const int32x4_t dst_wave_buffer_resource =
+        make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
+    index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
+    index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);
+
+    if constexpr(oob_conditional_check)
+    {
+        buffer_atomic_add_if<T, N, pre_nop>{}(src_thread_data,
+                                              dst_wave_buffer_resource,
+                                              dst_thread_addr_offset,
+                                              0,
+                                              dst_linear_addr_offset,
+                                              dst_thread_element_valid);
+    }
+    else
+    {
+        buffer_atomic_add<T, N, pre_nop>{}(src_thread_data,
+                                           dst_wave_buffer_resource,
+                                           dst_thread_addr_offset,
+                                           0,
+                                           dst_linear_addr_offset,
+                                           1);
+    }
+}
+
 // buffer_atomic_max requires:
 // 1) p_dst_wave must point to global memory
 // 2) p_dst_wave must be a wavewise pointer.
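The conditional variant added above predicates the packed-bf16 atomic per lane: v_cmpx_le_u32 shrinks the exec mask to lanes whose flag is at least 1, global_atomic_pk_add_bf16 then runs only on those lanes, and s_mov_b64 restores the saved exec mask. A hedged sketch of a caller, assuming a surrounding ck_tile kernel (p_out, out_elements and accumulate_partial are illustrative names, not part of the commit):

    #include "ck_tile/core.hpp"

    // Atomically accumulate a packed pair of bf16 partials into global memory.
    // oob_conditional_check defaults to true, which routes to buffer_atomic_add_if
    // and skips lanes whose is_valid flag is 0 via the exec-mask trick above.
    CK_TILE_DEVICE void accumulate_partial(ck_tile::bf16_t* p_out,
                                           ck_tile::index_t out_elements,
                                           ck_tile::index_t thread_offset,
                                           ck_tile::index_t linear_offset,
                                           bool is_valid,
                                           const ck_tile::thread_buffer<ck_tile::bf16_t, 2>& partial)
    {
        ck_tile::amd_buffer_atomic_add_raw<ck_tile::bf16_t, 2>(
            partial, p_out, thread_offset, linear_offset, is_valid, out_elements);
    }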
include/ck_tile/core/tensor/buffer_view.hpp

@@ -437,34 +437,74 @@ struct buffer_view<address_space_enum::global,
     // i is offset of T, not X. i should be aligned to X
     template <memory_operation_enum Op,
               typename X,
+              bool oob_conditional_check = true,
               typename std::enable_if<
                   std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                   bool>::type = false>
-    CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
+    CK_TILE_DEVICE void update(index_t i,
+                               index_t linear_offset,
+                               bool is_valid_element,
+                               const X& x,
+                               bool_constant<oob_conditional_check> = {})
     {
         if constexpr(Op == memory_operation_enum::set)
         {
-            this->template set<X>(i, linear_offset, is_valid_element, x);
+            this->template set<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
         }
         else if constexpr(Op == memory_operation_enum::atomic_add)
         {
-            this->template atomic_add<X>(i, linear_offset, is_valid_element, x);
+            this->template atomic_add<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
         }
         else if constexpr(Op == memory_operation_enum::atomic_max)
         {
-            this->template atomic_max<X>(i, linear_offset, is_valid_element, x);
+            this->template atomic_max<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
         }
         // FIXME: remove memory_operation_enum::add
         else if constexpr(Op == memory_operation_enum::add)
         {
-            auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
-            this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
+            auto tmp =
+                this->template get<X, oob_conditional_check>(i, linear_offset, is_valid_element);
+            this->template set<X, oob_conditional_check>(i, linear_offset, is_valid_element, x + tmp);
             // tmp += x;
             // this->template set<X>(i, is_valid_element, tmp);
         }
     }

+    // i is offset of T, not X. i should be aligned to X
+    template <memory_operation_enum Op,
+              typename X,
+              bool oob_conditional_check = true,
+              bool pre_nop = false,
+              typename std::enable_if<
+                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
+                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
+                  bool>::type = false>
+    CK_TILE_DEVICE void update_raw(index_t i,
+                                   index_t linear_offset,
+                                   bool is_valid_element,
+                                   const X& x,
+                                   bool_constant<oob_conditional_check> = {},
+                                   bool_constant<pre_nop> = {})
+    {
+        if constexpr(Op == memory_operation_enum::set)
+        {
+            this->template set_raw<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
+        }
+        else if constexpr(Op == memory_operation_enum::atomic_add)
+        {
+            this->template atomic_add_raw<X, oob_conditional_check, pre_nop>(
+                i, linear_offset, is_valid_element, x);
+        }
+        else if constexpr(Op == memory_operation_enum::atomic_max)
+        {
+            // this->template atomic_max_raw<X>(i, linear_offset, is_valid_element, x);
+        }
+    }
+
     // i is offset of T, not X. i should be aligned to X
     template <typename X,
               bool oob_conditional_check = true,
@@ -533,6 +573,7 @@ struct buffer_view<address_space_enum::global,
     }

     template <typename X,
+              bool oob_conditional_check = true,
               typename std::enable_if<
                   std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
@@ -585,6 +626,57 @@ struct buffer_view<address_space_enum::global,
     }

+    template <typename X,
+              bool oob_conditional_check = true,
+              bool pre_nop = true,
+              typename std::enable_if<
+                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
+                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
+                  bool>::type = false>
+    CK_TILE_DEVICE void
+    atomic_add_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
+    {
+        using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
+
+        // X contains multiple T
+        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
+        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
+
+        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
+                      "wrong! X should contain multiple T");
+
+        static_assert(get_address_space() == address_space_enum::global,
+                      "only support global mem");
+
+#if CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
+        bool constexpr use_amd_buffer_addressing =
+            std::is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
+            std::is_same_v<remove_cvref_t<scalar_t>, float> ||
+            (std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
+            (std::is_same_v<remove_cvref_t<scalar_t>, bf16_t> && scalar_per_x_vector % 2 == 0);
+#elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
+        bool constexpr use_amd_buffer_addressing =
+            std::is_same_v<remove_cvref_t<scalar_t>, int32_t>;
+#elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
+        bool constexpr use_amd_buffer_addressing =
+            std::is_same_v<remove_cvref_t<scalar_t>, float> ||
+            (std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
+            (std::is_same_v<remove_cvref_t<scalar_t>, bf16_t> && scalar_per_x_vector % 2 == 0);
+#else
+        bool constexpr use_amd_buffer_addressing = false;
+#endif
+
+        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
+
+        amd_buffer_atomic_add_raw<remove_cvref_t<T>,
+                                  t_per_x,
+                                  Coherence,
+                                  oob_conditional_check,
+                                  pre_nop>(
+            x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
+    }
+
     template <typename X,
               bool oob_conditional_check = true,
               typename std::enable_if<
                   std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
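update() and the new update_raw() above fold several memory operations behind one compile-time switch. A self-contained, simplified sketch of that if-constexpr dispatch pattern (scalar stand-ins only; the real buffer_view methods carry offsets, validity flags and vector types):

    #include <cstdio>

    enum class memory_op { set, atomic_add, atomic_max };

    template <memory_op Op>
    void update(float& dst, float x)
    {
        if constexpr(Op == memory_op::set)
            dst = x;
        else if constexpr(Op == memory_op::atomic_add)
            dst += x; // stand-in for the hardware atomic
        else if constexpr(Op == memory_op::atomic_max)
            dst = dst > x ? dst : x;
    }

    int main()
    {
        float v = 1.0f;
        update<memory_op::atomic_add>(v, 2.0f);
        std::printf("%g\n", v); // 3
        return 0;
    }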
include/ck_tile/core/tensor/load_tile.hpp

@@ -22,28 +22,32 @@ template <typename BottomTensorView_,
           typename WindowLengths_,
           typename TileDistribution_,
           index_t NumCoord,
+          index_t i_access = -1,
           bool oob_conditional_check = true>
 CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
                                                                          WindowLengths_,
                                                                          TileDistribution_,
                                                                          NumCoord>& tile_window,
+                              number<i_access> = {},
                               bool_constant<oob_conditional_check> = {})
 {
-    return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
+    return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
 }

 template <typename BottomTensorView_,
           typename WindowLengths_,
           typename TileDistribution_,
           typename LinearBottomDims_,
+          index_t i_access = -1,
           bool oob_conditional_check = true>
 CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
                                                        WindowLengths_,
                                                        TileDistribution_,
                                                        LinearBottomDims_>& tile_window,
+                              number<i_access> = {},
                               bool_constant<oob_conditional_check> = {})
 {
-    return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
+    return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
 }

 template <typename DistributedTensor_,
@@ -51,15 +55,17 @@ template <typename DistributedTensor_,
           typename WindowLengths_,
           typename TileDistribution_,
           index_t NumCoord,
+          index_t i_access = -1,
           bool oob_conditional_check = true>
 CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
                               const tile_window_with_static_distribution<BottomTensorView_,
                                                                          WindowLengths_,
                                                                          TileDistribution_,
                                                                          NumCoord>& tile_window,
+                              number<i_access> = {},
                               bool_constant<oob_conditional_check> = {})
 {
-    return tile_window.load(dst_tile, bool_constant<oob_conditional_check>{});
+    return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
 }

 /**
@@ -76,6 +82,7 @@ template <typename T,
           typename WindowLengths_,
           typename TileDistribution_,
           index_t NumCoord,
+          index_t i_access = -1,
           bool oob_conditional_check = true,
           bool pre_nop = false>
 CK_TILE_DEVICE auto load_tile_raw(T& tile,
@@ -83,11 +90,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
                                   WindowLengths_,
                                   TileDistribution_,
                                   NumCoord>& tile_window,
+                                  number<i_access> = {},
                                   bool_constant<oob_conditional_check> = {},
                                   bool_constant<pre_nop> = {})
 {
     tile_window.load_raw(
-        tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+        tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
 }

 template <typename T,
@@ -95,6 +103,7 @@ template <typename T,
           typename WindowLengths_,
           typename TileDistribution_,
           typename LinearBottomDims_,
+          index_t i_access = -1,
           bool oob_conditional_check = true,
           bool pre_nop = false>
 CK_TILE_DEVICE auto load_tile_raw(T& tile,
@@ -102,11 +111,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
                                   WindowLengths_,
                                   TileDistribution_,
                                   LinearBottomDims_>& tile_window,
+                                  number<i_access> = {},
                                   bool_constant<oob_conditional_check> = {},
                                   bool_constant<pre_nop> = {})
 {
     tile_window.load_raw(
-        tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+        tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
 }

 template <typename LdsTileWindow_,
@@ -114,6 +124,7 @@ template <typename LdsTileWindow_,
           typename WindowLengths_,
           typename TileDistribution_,
           index_t NumCoord,
+          index_t i_access = -1,
           bool oob_conditional_check = true,
           bool pre_nop = false>
 CK_TILE_DEVICE auto
@@ -122,11 +133,14 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
                     WindowLengths_,
                     TileDistribution_,
                     NumCoord>& tile_window,
+                    number<i_access> = {},
                     bool_constant<oob_conditional_check> = {},
                     bool_constant<pre_nop> = {})
 {
     return tile_window.async_load_raw(
-        lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+        lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
 }

 template <typename LdsTileWindow_,
@@ -134,6 +148,7 @@ template <typename LdsTileWindow_,
           typename WindowLengths_,
           typename TileDistribution_,
           typename LinearBottomDims_,
+          index_t i_access = -1,
           bool oob_conditional_check = true,
           bool pre_nop = false>
 CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
@@ -141,11 +156,14 @@ CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
                                         WindowLengths_,
                                         TileDistribution_,
                                         LinearBottomDims_>& tile_window,
+                                        number<i_access> = {},
                                         bool_constant<oob_conditional_check> = {},
                                         bool_constant<pre_nop> = {})
 {
     return tile_window.async_load_raw(
-        lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+        lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
 }

 CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
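Every overload above gains an i_access template parameter defaulting to -1 ("issue all accesses"), passed as a number<> tag so existing call sites keep compiling unchanged. A self-contained sketch of that default-tag idiom (local number<> stand-in, not ck_tile's):

    #include <cstdio>

    template <int v>
    struct number
    {
        static constexpr int value = v;
    };

    template <int i_access = -1>
    void issue(number<i_access> = {})
    {
        if constexpr(i_access < 0)
            std::printf("issue every access\n");
        else
            std::printf("issue only access %d\n", i_access);
    }

    int main()
    {
        issue();            // old behavior: all accesses
        issue(number<2>{}); // new: one access, chosen at compile time
        return 0;
    }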
include/ck_tile/core/tensor/tensor_view.hpp

@@ -333,6 +333,48 @@ struct tensor_view
             coord.get_offset(), linear_offset, is_valid_element, x);
     }

+    // X is vector of DataType.
+    // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
+    template <typename X,
+              bool oob_conditional_check = true,
+              bool pre_nop = false,
+              typename std::enable_if<
+                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
+                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
+                  bool>::type = false>
+    CK_TILE_HOST_DEVICE constexpr void
+    update_vectorized_elements_raw(const TensorCoord& coord,
+                                   index_t linear_offset,
+                                   const X& x,
+                                   bool_constant<oob_conditional_check> = {},
+                                   bool_constant<pre_nop> = {})
+    {
+        buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
+            coord.get_offset(),
+            linear_offset,
+            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
+            x);
+    }
+
+    template <typename X,
+              bool oob_conditional_check = true,
+              bool pre_nop = false,
+              typename std::enable_if<
+                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
+                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
+                  bool>::type = false>
+    CK_TILE_HOST_DEVICE constexpr void
+    update_vectorized_elements_raw(const TensorCoord& coord,
+                                   index_t linear_offset,
+                                   bool is_valid_element,
+                                   const X& x,
+                                   bool_constant<oob_conditional_check> = {},
+                                   bool_constant<pre_nop> = {})
+    {
+        buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
+            coord.get_offset(), linear_offset, is_valid_element, x);
+    }
+
     CK_TILE_HOST_DEVICE void print() const
     {
         printf("tensor_view{");
include/ck_tile/core/tensor/tile_window.hpp

@@ -785,6 +785,73 @@ struct tile_window_with_static_distribution
         });
     }

+    template <index_t i_access_unsupport_ = -1,
+              bool oob_conditional_check = true,
+              bool pre_nop>
+    CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
+                                   number<i_access_unsupport_> = {},
+                                   bool_constant<oob_conditional_check> = {},
+                                   bool_constant<pre_nop> = {}) const
+    {
+        using Traits   = load_store_traits;
+        using vector_t = typename Traits::vector_t;
+        using SFC_Ys   = typename Traits::SFC_Ys;
+
+        constexpr auto tile_dstr = TileDstr{};
+
+        // loop over thread tensor space [y0, y1, ...]
+        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
+            /// TODO: use structure binding (to be captured later) if compiled in C++20
+            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
+            auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];
+
+            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
+                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
+
+                // data index [y0, y1, ...]
+                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
+
+                // read from distributed tensor
+                vector_t vec_value;
+
+                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
+                    constexpr auto idx_ys = generate_tuple(
+                        [&](auto jj) {
+                            return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
+                                                            : idx_ys_start[jj];
+                        },
+                        number<NDimY>{});
+
+                    constexpr index_t d =
+                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
+
+                    vec_value.template get_as<DataType>()(j) =
+                        dstr_tensor.get_thread_buffer().template at<d>();
+                });
+
+                // write into bottom tensor
+                get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
+                    bottom_tensor_thread_coord,
+                    0,
+                    vec_value,
+                    bool_constant<oob_conditional_check>{},
+                    bool_constant<pre_nop>{});
+
+                // move thread coordinate
+                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
+                {
+                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
+
+                    constexpr auto idx_diff_ps_ys = container_concat(
+                        generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
+                        idx_diff_ys);
+
+                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
+                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
+                }
+            });
+        });
+    }
+
     // move thread's bottom tensor coordinate
     // [x0', x1', ... ] ==> [offset]
     // also move window-origin
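update_raw above walks the access order given by the space-filling curve SFC_Ys and, for each access, gathers ScalarPerVector elements of the distributed tensor's thread buffer into one vector before a single raw vectorized write. A simplified, self-contained sketch of that gather step (the real offsets come from the ys-to-d descriptor of the tile distribution):

    #include <array>
    #include <cstdio>

    int main()
    {
        constexpr int ScalarPerVector = 4;
        std::array<float, 16> thread_buffer{};
        for(int i = 0; i < 16; i++)
            thread_buffer[i] = float(i);

        // offsets a descriptor would compute for one access along the vector dim
        int d_start = 8, d_stride = 1;
        std::array<float, ScalarPerVector> vec_value{};
        for(int j = 0; j < ScalarPerVector; j++)
            vec_value[j] = thread_buffer[d_start + j * d_stride];

        for(float v : vec_value)
            std::printf("%g ", v); // 8 9 10 11
        std::printf("\n");
        return 0;
    }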
include/ck_tile/core/tensor/tile_window_linear.hpp

@@ -849,6 +849,58 @@ struct tile_window_linear
         WINDOW_DISPATCH_ISSUE();
     }

+    template <index_t i_access = -1, bool oob_conditional_check = true, bool pre_nop = false>
+    CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
+                                   number<i_access> = {},
+                                   bool_constant<oob_conditional_check> = {},
+                                   bool_constant<pre_nop> = {}) const
+    {
+        using vector_t = typename traits::vector_t;
+        using SFC_Ys   = typename traits::SFC_Ys;
+
+        constexpr auto tile_dstr = TileDstr{};
+
+        // loop over thread tensor space [y0, y1, ...]
+        auto issue = [&](auto i_access_) {
+            constexpr auto IAccess       = number<i_access_>{};
+            constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
+
+            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
+            constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);
+            auto bottom_tensor_flag         = cached_flags_[IAccess];
+
+            // data index [y0, y1, ...]
+            constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
+
+            // read from distributed tensor
+            vector_t vec_value;
+
+            static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
+                constexpr auto idx_ys = generate_tuple(
+                    [&](auto jj) {
+                        return jj == traits::VectorDimY ? (idx_ys_start[jj] + j)
+                                                        : idx_ys_start[jj];
+                    },
+                    number<NDimY>{});
+
+                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
+
+                vec_value.template get_as<DataType>()(j) =
+                    dstr_tensor.get_thread_buffer().template at<d>();
+            });
+
+            // write into bottom tensor
+            get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
+                bottom_tensor_thread_coord,
+                linear_offset,
+                bottom_tensor_flag,
+                vec_value,
+                bool_constant<oob_conditional_check>{},
+                bool_constant<pre_nop>{});
+        };
+        WINDOW_DISPATCH_ISSUE();
+    }
+
     // move thread's bottom tensor coordinate
     // [x0', x1', ... ] ==> [offset]
     // also move window-origin
include/ck_tile/core/tensor/update_tile.hpp

@@ -41,15 +41,64 @@ template <typename BottomTensorView_,
           typename WindowLengths_,
           typename TileDistribution_,
           index_t NumCoord,
-          typename DataType_>
+          typename DataType_,
+          index_t i_access = -1,
+          bool oob_conditional_check = true>
 CK_TILE_DEVICE void
 update_tile(tile_window_with_static_distribution<BottomTensorView_,
                                                  WindowLengths_,
                                                  TileDistribution_,
                                                  NumCoord>& tile_window,
-            const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
+            const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
+            number<i_access> = {},
+            bool_constant<oob_conditional_check> = {})
 {
-    tile_window.update(dstr_tensor);
+    tile_window.update(dstr_tensor, number<i_access>{}, bool_constant<oob_conditional_check>{});
 }

+template <typename BottomTensorView_,
+          typename WindowLengths_,
+          typename TileDistribution_,
+          index_t NumCoord,
+          typename DataType_,
+          index_t i_access = -1,
+          bool oob_conditional_check = true,
+          bool pre_nop = false>
+CK_TILE_DEVICE void
+update_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
+                                                     WindowLengths_,
+                                                     TileDistribution_,
+                                                     NumCoord>& tile_window,
+                const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
+                number<i_access> = {},
+                bool_constant<oob_conditional_check> = {},
+                bool_constant<pre_nop> = {})
+{
+    tile_window.update_raw(dstr_tensor,
+                           number<i_access>{},
+                           bool_constant<oob_conditional_check>{},
+                           bool_constant<pre_nop>{});
+}
+
+template <typename BottomTensorView_,
+          typename WindowLengths_,
+          typename TileDistribution_,
+          typename LinearBottomDims_,
+          index_t i_access = -1,
+          bool oob_conditional_check = true,
+          bool pre_nop = false>
+CK_TILE_DEVICE auto
+update_tile_raw(tile_window_linear<BottomTensorView_,
+                                   WindowLengths_,
+                                   TileDistribution_,
+                                   LinearBottomDims_>& tile_window,
+                const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
+                number<i_access> = {},
+                bool_constant<oob_conditional_check> = {},
+                bool_constant<pre_nop> = {})
+{
+    tile_window.update_raw(dstr_tensor,
+                           number<i_access>{},
+                           bool_constant<oob_conditional_check>{},
+                           bool_constant<pre_nop>{});
+}
+
 } // namespace ck_tile
include/ck_tile/host/reference/reference_permute.hpp
View file @
49c39b51
...
@@ -16,7 +16,7 @@ namespace ck_tile {
...
@@ -16,7 +16,7 @@ namespace ck_tile {
*/
*/
template <typename DataType>
CK_TILE_HOST void
reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> perm)
{
    const auto x_len = x.mDesc.get_lengths();
    const auto y_len = y.mDesc.get_lengths();
...
@@ -43,7 +43,7 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
        std::vector<size_t> tmp(rank, 0);
        for(index_t i = 0; i < rank; i++)
        {
            tmp[perm[i]] = y_coord[i];
        }
        return tmp;
    }();
...
@@ -54,4 +54,24 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
    make_ParallelTensorFunctor(f, x_elm)(std::thread::hardware_concurrency());
}

template <typename DataType>
CK_TILE_HOST auto reference_permute(const HostTensor<DataType>& x, std::vector<index_t> perm)
{
    auto x_shape          = x.get_lengths();
    ck_tile::index_t rank = perm.size();

    std::vector<ck_tile::index_t> y_shape = [&]() {
        std::vector<ck_tile::index_t> tmp(rank, 0);
        for(int i = 0; i < static_cast<int>(rank); i++)
        {
            tmp[i] = x_shape[perm[i]];
        }
        return tmp;
    }();

    HostTensor<DataType> y(y_shape);
    reference_permute(x, y, perm);
    return y;
}
} // namespace ck_tile
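The permutation convention here is that `perm[i]` names the input axis that maps to output axis `i`, so `y_shape[i] = x_shape[perm[i]]`. A tiny self-contained sketch of the same semantics on flat `std::vector` storage (plain host C++, independent of HostTensor, for illustration only):

#include <cstddef>
#include <vector>

// permute a row-major tensor: output axis i takes input axis perm[i]
std::vector<float> permute(const std::vector<float>& x,
                           const std::vector<int>& x_shape,
                           const std::vector<int>& perm)
{
    const int rank = static_cast<int>(perm.size());
    std::vector<int> y_shape(rank);
    for(int i = 0; i < rank; i++)
        y_shape[i] = x_shape[perm[i]];

    // row-major strides of the input
    std::vector<std::size_t> x_stride(rank, 1);
    for(int i = rank - 2; i >= 0; i--)
        x_stride[i] = x_stride[i + 1] * x_shape[i + 1];

    std::vector<float> y(x.size());
    std::vector<int> y_coord(rank, 0);
    for(std::size_t e = 0; e < y.size(); e++)
    {
        // gather: map the output coordinate back to an input offset
        std::size_t x_off = 0;
        for(int i = 0; i < rank; i++)
            x_off += static_cast<std::size_t>(y_coord[i]) * x_stride[perm[i]];
        y[e] = x[x_off];
        // advance the row-major output coordinate
        for(int i = rank - 1; i >= 0; i--)
        {
            if(++y_coord[i] < y_shape[i])
                break;
            y_coord[i] = 0;
        }
    }
    return y;
}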
include/ck_tile/ops/fused_moe.hpp 0 → 100644 View file @ 49c39b51
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fused_moe/kernel/fused_moe_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moe_shape.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moe_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_flatmm.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moe_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp 0 → 100644 View file @ 49c39b51
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include <string>
#include <type_traits>
// clang-format off
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, top_k=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 4]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_tokens_post_padded : top_k * input_tokens + num_experts * (M_a - 1)
// * this could be larger than the actual count, since the actual count is only known on the GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 4]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// * length is max_tokens_post_padded, actual size is num_tokens_post_padded_ptr
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_tokens_post_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// * different from vLLM
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
// 2) need sorted_weight_ptr
// 3) use num_sorted_tiles_ptr, already divided by M_a
//
// * below used for indexing
// 1) sorted_token_ids_ptr
// 2) sorted_weight_ptr
// 3) sorted_expert_ids_ptr
// 4) num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
//
// [indexing implementation-2]
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// we generate original row/col id as
// topk_rc_ids : [[0, 5, A], [1, 6, B], [2, 7, C], [3, 8, D], [4, 9, E]]
// let x be one element of above, we can get:
// topk_row_id(token_id) = x % num_tokens(5)
// topk_col_id(expert_id) = x / num_tokens
// topk_row_id/col_id can be used to access original topk_ids/topk_weight
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 4]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// we can get permuted_rc_ids:
// [[0], [2, 3, 4], [1, 8], [5, 6, 7, D, 9], [], [A, B, C, E]]
//
//
// clang-format on
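To make the indexing scheme above concrete, here is a small host-side sketch (illustration only, not the sorting kernel that feeds this GEMM) that builds sorted_token_ids / sorted_weight / sorted_expert_ids from topk_ids the way implementation-1 describes, padding each expert's token list to a multiple of M_a (the worked example above uses pad_id = 6 for 5 tokens, and keeps one all-padding tile for the empty exp-4):

#include <cstddef>
#include <cstdint>
#include <vector>

// host-side illustration of [indexing implementation-1]; pad_id marks padded slots
struct MoeSorted
{
    std::vector<int32_t> token_ids;  // sorted_token_ids_ptr contents
    std::vector<float>   weights;    // sorted_weight_ptr contents
    std::vector<int32_t> expert_ids; // sorted_expert_ids_ptr, one entry per M_a tile
};

MoeSorted moe_sort(const std::vector<std::vector<int32_t>>& topk_ids,
                   const std::vector<std::vector<float>>& topk_weight,
                   int32_t num_experts, int32_t m_a, int32_t pad_id)
{
    // bucket tokens (and their weights) per expert
    std::vector<std::vector<int32_t>> tok(num_experts);
    std::vector<std::vector<float>>   wei(num_experts);
    for(std::size_t t = 0; t < topk_ids.size(); t++)
        for(std::size_t k = 0; k < topk_ids[t].size(); k++)
        {
            int32_t e = topk_ids[t][k];
            tok[e].push_back(static_cast<int32_t>(t)); // actual token_id, not t*top_k + k
            wei[e].push_back(topk_weight[t][k]);
        }

    MoeSorted s;
    for(int32_t e = 0; e < num_experts; e++)
    {
        int32_t n     = static_cast<int32_t>(tok[e].size());
        int32_t tiles = (n + m_a - 1) / m_a;
        if(tiles == 0)
            tiles = 1; // matches the worked example: exp-4 keeps one padded tile
        for(int32_t i = 0; i < tiles * m_a; i++)
        {
            s.token_ids.push_back(i < n ? tok[e][i] : pad_id);
            s.weights.push_back(i < n ? wei[e][i] : 0.f);
        }
        for(int32_t i = 0; i < tiles; i++)
            s.expert_ids.push_back(e);
    }
    // num_tokens_post_padded = s.token_ids.size(), num_sorted_tiles = s.expert_ids.size()
    return s;
}

Running this on the worked example (5 tokens, 6 experts, top_k = 3, M_a = 4, pad_id = 6) reproduces the arrays above: 28 sorted token ids and the 7 tile expert ids [0, 1, 2, 3, 3, 4, 5].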
namespace ck_tile {
// m: num_tokens (or token*input-batch)
// k: hidden_size
// n: intermediate_size used between 2 FC (TP slice this)
// e: num expert
// if doing pre-shuffle
// nr : n / Block_Nr
// kr : k / Block_Kr
// w : flattened 1d wave buffer
struct FusedMoeGemmHostArgs
{
    const void* a_ptr;              // [m, k], input token
    const void* a_scale_ptr;        // [m, 1], token scale
    const void* g_ptr;              // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
    const void* d_ptr;              // [e, n, k], pre-shuffle([e, nr, kr, w])
    const void* g_scale_ptr;        // [e, 1, n], gate(up) scale
    const void* d_scale_ptr;        // [e, 1, k], down scale
    const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
    void* o_ptr;                    // [m, k], output token

    const void* sorted_token_ids_ptr;
    const void* sorted_weight_ptr;
    const void* sorted_expert_ids_ptr;
    const void* num_sorted_tiles_ptr;

    index_t hidden_size;       // k
    index_t intermediate_size; // n (TP slice this)
    index_t num_tokens;        // input number of tokens for current iteration
    index_t num_experts;       // number of groups
    // index_t top_k;          // need this?

    index_t stride_token; // for input/output, stride for each row, should >= hidden_size
};
// This is a scatter/gather b2b group-gemm
template <typename Partitioner_, typename Pipeline_, typename Epilogue_>
struct FusedMoeGemmKernel
{
    using Partitioner = remove_cvref_t<Partitioner_>;
    using Pipeline    = remove_cvref_t<Pipeline_>;
    using Epilogue    = remove_cvref_t<Epilogue_>; // TODO: not used
    static constexpr index_t kBlockSize = Pipeline::kBlockSize;
    // static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu;
    // static_assert(kBlockPerCu > 0);
    using BlockShape = typename Pipeline::BlockShape; // this is FusedMoeGemmShape

    using ADataType            = typename Pipeline::Problem::ADataType;
    using GDataType            = typename Pipeline::Problem::GDataType;
    using DDataType            = typename Pipeline::Problem::DDataType;
    using AccDataType          = typename Pipeline::Problem::AccDataType;
    using ODataType            = typename Pipeline::Problem::ODataType;
    using AScaleDataType       = typename Pipeline::Problem::AScaleDataType;
    using GScaleDataType       = typename Pipeline::Problem::GScaleDataType;
    using DScaleDataType       = typename Pipeline::Problem::DScaleDataType;
    using YSmoothScaleDataType = typename Pipeline::Problem::YSmoothScaleDataType;
    using TopkWeightDataType   = typename Pipeline::Problem::TopkWeightDataType;
    using IndexDataType        = typename Pipeline::Problem::IndexDataType;
    using YDataType            = typename Pipeline::Problem::YDataType;

    using Traits = typename Pipeline::Problem::Traits;
    static constexpr bool IsGateOnly          = Traits::IsGateOnly;
    static constexpr bool UseSmoothQuant      = Traits::UseSmoothQuant;
    static constexpr bool PadHiddenSize       = Traits::PadHiddenSize;
    static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
    // clang-format off
    template <typename T> struct t2s;
    template <> struct t2s<float>  { static constexpr const char* name = "fp32"; };
    template <> struct t2s<fp16_t> { static constexpr const char* name = "fp16"; };
    template <> struct t2s<bf16_t> { static constexpr const char* name = "bf16"; };
    template <> struct t2s<fp8_t>  { static constexpr const char* name = "fp8"; };
    template <> struct t2s<bf8_t>  { static constexpr const char* name = "bf8"; };
    template <> struct t2s<int8_t> { static constexpr const char* name = "int8"; };
    // clang-format on
    CK_TILE_HOST static std::string GetName()
    {
        // sync with generate.py
        // clang-format off
        // clang-format on
        return std::string{}; // TODO: compose the kernel name here
    }
    struct FusedMoeGemmKargs
    {
        const void* a_ptr;              // [m, k], input token
        const void* a_scale_ptr;        // [m, 1], token scale
        const void* g_ptr;              // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
        const void* d_ptr;              // [e, n, k], pre-shuffle([e, nr, kr, w])
        const void* g_scale_ptr;        // [e, 1, n], gate(up) scale
        const void* d_scale_ptr;        // [e, 1, k], down scale
        const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
        void* o_ptr;                    // [m, k], output token

        const void* sorted_token_ids_ptr;
        const void* sorted_weight_ptr;
        const void* sorted_expert_ids_ptr;
        const void* num_sorted_tiles_ptr;

        index_t hidden_size;       // k
        index_t intermediate_size; // n (TP slice this)
        index_t num_tokens;        // input number of tokens for current iteration
        index_t num_experts;       // number of groups
        // index_t top_k;          // need this?

        index_t stride_token; // for input/output, stride for each row, should >= hidden_size
    };
    // TODO: switch karg based on
    using Kargs = FusedMoeGemmKargs;
    using Hargs = FusedMoeGemmHostArgs;

    CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
    {
        // TODO: hargs/kargs are not guaranteed to be the same
        return bit_cast<Kargs>(hargs);
    }
    CK_TILE_HOST static constexpr auto GridSize(index_t num_cu, index_t blocks_per_cu)
    {
        return Partitioner::GridSize(num_cu, blocks_per_cu);
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return max(Pipeline::GetSmemSize(), Epilogue::GetSmemSize());
    }
    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        // allocate LDS
        // __shared__ char smem_ptr[GetSmemSize()];
        IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
            *reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));

        constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;

        index_t nr_0 = kargs.intermediate_size / Pipeline::Block_Nr0;
        index_t kr_0 = kargs.hidden_size / Pipeline::Block_Kr0;
        index_t nr_1 = kargs.hidden_size / Pipeline::Block_Nr1;       // should be same as kr_0
        index_t kr_1 = kargs.intermediate_size / Pipeline::Block_Kr1; // should be same as nr_0

        index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
        index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;

        __shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
        // note this is in unit of tiles; multiply by the tile size to get the element index
        const auto [sorted_tile_id, hidden_tile_id] =
            Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
        if(sorted_tile_id >= num_sorted_tiles)
            return;

        const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
            reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);

        // index along intermediate_size
        index_t hidden_idx = __builtin_amdgcn_readfirstlane(hidden_tile_id * BlockShape::Block_N0);
        index_t hidden_idx_nr =
            __builtin_amdgcn_readfirstlane(hidden_tile_id * BlockShape::Block_Nr0);

        const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
        const auto sorted_token_id = a_coord[number<0>{}] + sorted_tile_id * BlockShape::Block_M0;

        index_t token_id =
            reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
        auto topk_weight =
            reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)[sorted_token_id];
        const auto a_window = [&]() {
            // A is already pre-padded in previous kernel
            const ADataType* a_ptr = reinterpret_cast<const ADataType*>(kargs.a_ptr);
            const auto a_view_ = make_naive_tensor_view<address_space_enum::global>(
                a_ptr,
                make_tuple(kargs.num_tokens, kargs.hidden_size),
                make_tuple(kargs.stride_token, 1),
                number<Pipeline::kAlignmentA>{},
                number<1>{});

            // gather is here, use indexing transform
            const auto a_gather_view_ = transform_tensor_view(
                a_view_,
                make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
                           make_pass_through_transform(kargs.hidden_size)),
                make_tuple(sequence<0>{}, sequence<1>{}),
                make_tuple(sequence<0>{}, sequence<1>{}));

            const auto a_window_ = make_tile_window(
                a_gather_view_,
                make_tuple(number<BlockShape::Block_M0>{}, number<Pipeline::Block_K0>{}),
                {0, 0});
            return a_window_;
        }();
        // TODO: g tile could use NSub to reduce register pressure
        const auto g_window = [&]() {
            const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
                                     static_cast<long_index_t>(expert_id) * expert_stride_0 +
                                     hidden_idx_nr * kr_0 * BlockShape::Block_W0;
            const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
                g_ptr,
                make_tuple(nr_0, kr_0, number<Pipeline::Block_W0>{}),
                make_tuple(kr_0 * BlockShape::Block_W0, number<Pipeline::Block_W0>{}, 1),
                number<Pipeline::kAlignmentG>{},
                number<1>{});

            const auto g_view_1_ = pad_tensor_view(
                g_view_,
                make_tuple(number<Pipeline::Block_Nr0>{},
                           number<Pipeline::Block_Kr0>{},
                           number<Pipeline::Block_W0>{}),
                sequence<PadIntermediateSize, PadHiddenSize, 0>{});

            const auto g_window_ = make_tile_window(
                g_view_1_,
                make_tuple(number<BlockShape::Block_Nr0>{},
                           number<Pipeline::Block_Kr0>{},
                           number<Pipeline::Block_W0>{}),
                {0, 0, 0});
            return g_window_;
        }();
        const auto d_window = [&]() {
            const DDataType* d_ptr = [&]() {
                // note hidden_idx_nr is along the gemm-k dim of the 2nd gemm
                return reinterpret_cast<const DDataType*>(kargs.d_ptr) +
                       static_cast<long_index_t>(expert_id) * expert_stride_1 +
                       hidden_idx_nr * BlockShape::Block_W1;
            }();
            const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
                d_ptr,
                make_tuple(nr_1, kr_1, Pipeline::Block_W1),
                make_tuple(kr_1 * Pipeline::Block_W1, Pipeline::Block_W1, 1),
                number<Pipeline::kAlignmentD>{},
                number<1>{});

            const auto d_view_1_ = pad_tensor_view(
                d_view_,
                make_tuple(number<Pipeline::kBlockNr_1>{},
                           number<Pipeline::kBlockKr_1>{},
                           number<Pipeline::Block_W1>{}),
                sequence<PadHiddenSize, PadIntermediateSize, 0>{});

            const auto d_window_ = make_tile_window(
                d_view_1_,
                make_tuple(number<Pipeline::kBlockNr_1>{},
                           number<Pipeline::kBlockKr_1>{},
                           number<Pipeline::Block_W1>{}),
                {0, 0, 0});
            return d_window_;
        }();
        auto o_window = [&]() {
            ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr);
            const auto o_view_ = make_naive_tensor_view<address_space_enum::global,
                                                        memory_operation_enum::atomic_add>(
                o_ptr,
                make_tuple(kargs.num_tokens, kargs.hidden_size),
                make_tuple(kargs.stride_token, 1),
                number<Pipeline::kAlignmentO>{},
                number<1>{});

            // scatter is here, use indexing transform
            const auto o_scatter_view_ = transform_tensor_view(
                o_view_,
                make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
                           make_pass_through_transform(kargs.hidden_size)),
                make_tuple(sequence<0>{}, sequence<1>{}),
                make_tuple(sequence<0>{}, sequence<1>{}));

            const auto o_window_ = make_tile_window(
                o_scatter_view_,
                make_tuple(number<BlockShape::Block_M1>{}, number<Pipeline::Block_N1>{}),
                {0, 0});
            return o_window_;
        }();
        // do compute
        Pipeline{}(a_window,
                   g_window,
                   d_window,
                   o_window,
                   topk_weight,
                   smem,
                   kargs.hidden_size,
                   kargs.intermediate_size);
    }
};
} // namespace ck_tile
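The gather/scatter structure of this kernel can be summarized in plain C++: `make_indexing_transform` with `token_id` makes the A window read row `token_id` of the token matrix, while the O window accumulates (via `atomic_add`) into the same row, so the top-k copies of one token all land in its single output row. A schematic scalar sketch of that addressing follows; it is illustration only, with assumptions not stated in the commit (row-major A/O, and the top-k weight applied to the second gemm's result just before the scatter-add):

#include <cstddef>
#include <vector>

// scalar model of the kernel's row addressing; the two gemms and the
// activation are elided, represented by the pass-through of 'y'
void moe_row_schematic(const std::vector<float>& a,            // [num_tokens, hidden]
                       std::vector<float>& o,                  // [num_tokens, hidden]
                       const std::vector<int>& sorted_token_ids,
                       const std::vector<float>& sorted_weight,
                       int num_tokens, int hidden)
{
    for(std::size_t s = 0; s < sorted_token_ids.size(); s++)
    {
        int token_id = sorted_token_ids[s];
        if(token_id >= num_tokens)
            continue; // padded slot, nothing to do
        for(int h = 0; h < hidden; h++)
        {
            // gather row 'token_id' of A ...
            float y = a[static_cast<std::size_t>(token_id) * hidden + h];
            // ... gemm0 + activation + gemm1 would transform y here ...
            // scatter: atomic_add in the kernel, a plain += in this model
            o[static_cast<std::size_t>(token_id) * hidden + h] += sorted_weight[s] * y;
        }
    }
}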
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp 0 → 100644 View file @ 49c39b51
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/*
 tensors:
 1. act  (A): input feature map
 2. gate (G): B matrix for first gemm, output will do activation(Silu)
 3. up   (U): B matrix for first gemm
 4. down (D): B matrix for second gemm

 (dataflow) A[M0, K0] is multiplied with the Gate and Up weights (each [N0, K0],
 contiguous along the hidden dim) in the first gemm; the activated gate output is
 combined element-wise with the up output into Y[M1, K1], with M1 = M0 and
 K1 = N0; the second gemm multiplies Y with Down[N1, K1] to produce Out(O).

 * Note: Act could be Gelu/Silu/...
 * Note: some model does not have Up
*/
template <typename BlockTile_0_,
          typename WarpPerBlock_0_,
          typename WarpTile_0_,
          typename BlockTile_1_,
          typename WarpPerBlock_1_,
          typename WarpTile_1_>
struct FusedMoeGemmShape
{
    using BlockTile_0    = remove_cvref_t<BlockTile_0_>;
    using WarpPerBlock_0 = remove_cvref_t<WarpPerBlock_0_>;
    using WarpTile_0     = remove_cvref_t<WarpTile_0_>;
    using BlockTile_1    = remove_cvref_t<BlockTile_1_>;
    using WarpPerBlock_1 = remove_cvref_t<WarpPerBlock_1_>;
    using WarpTile_1     = remove_cvref_t<WarpTile_1_>;
    static constexpr index_t NumWarps =
        reduce_on_sequence(WarpPerBlock_0{}, multiplies{}, number<1>{});
    static_assert(NumWarps == reduce_on_sequence(WarpPerBlock_1{}, multiplies{}, number<1>{}));
    static constexpr index_t Block_M0 = BlockTile_0::at(number<0>{});
    static constexpr index_t Block_N0 = BlockTile_0::at(number<1>{});
    static constexpr index_t Block_K0 = BlockTile_0::at(number<2>{});

    static constexpr index_t WarpPerBlock_M0 = WarpPerBlock_0::at(number<0>{});
    static constexpr index_t WarpPerBlock_N0 = WarpPerBlock_0::at(number<1>{});
    static constexpr index_t WarpPerBlock_K0 = WarpPerBlock_0::at(number<2>{});

    static constexpr index_t Warp_M0 = WarpTile_0::at(number<0>{});
    static constexpr index_t Warp_N0 = WarpTile_0::at(number<1>{});
    static constexpr index_t Warp_K0 = WarpTile_0::at(number<2>{});

    static constexpr index_t ThreadPerBlock_M0 = Warp_M0 * WarpPerBlock_M0;
    static constexpr index_t ThreadPerBlock_N0 = Warp_N0 * WarpPerBlock_N0;
    static constexpr index_t ThreadPerBlock_K0 = Warp_K0 * WarpPerBlock_K0;
    static_assert(Block_M0 % ThreadPerBlock_M0 == 0);
    static_assert(Block_N0 % ThreadPerBlock_N0 == 0);
    static_assert(Block_K0 % ThreadPerBlock_K0 == 0);

    static constexpr index_t Repeat_M0 = Block_M0 / ThreadPerBlock_M0;
    static constexpr index_t Repeat_N0 = Block_N0 / ThreadPerBlock_N0;
    static constexpr index_t Repeat_K0 = Block_K0 / ThreadPerBlock_K0;
    static constexpr index_t Block_M1 = BlockTile_1::at(number<0>{});
    static constexpr index_t Block_N1 = BlockTile_1::at(number<1>{});
    static constexpr index_t Block_K1 = BlockTile_1::at(number<2>{});

    static constexpr index_t WarpPerBlock_M1 = WarpPerBlock_1::at(number<0>{});
    static constexpr index_t WarpPerBlock_N1 = WarpPerBlock_1::at(number<1>{});
    static constexpr index_t WarpPerBlock_K1 = WarpPerBlock_1::at(number<2>{});

    static constexpr index_t Warp_M1 = WarpTile_1::at(number<0>{});
    static constexpr index_t Warp_N1 = WarpTile_1::at(number<1>{});
    static constexpr index_t Warp_K1 = WarpTile_1::at(number<2>{});

    static constexpr index_t ThreadPerBlock_M1 = Warp_M1 * WarpPerBlock_M1;
    static constexpr index_t ThreadPerBlock_N1 = Warp_N1 * WarpPerBlock_N1;
    static constexpr index_t ThreadPerBlock_K1 = Warp_K1 * WarpPerBlock_K1;
    static_assert(Block_M1 % ThreadPerBlock_M1 == 0);
    static_assert(Block_N1 % ThreadPerBlock_N1 == 0);
    static_assert(Block_K1 % ThreadPerBlock_K1 == 0);

    static constexpr index_t Repeat_M1 = Block_M1 / ThreadPerBlock_M1;
    static constexpr index_t Repeat_N1 = Block_N1 / ThreadPerBlock_N1;
    static constexpr index_t Repeat_K1 = Block_K1 / ThreadPerBlock_K1;

    static constexpr index_t BlockSize = warpSize * NumWarps;
    // some asserts
    static_assert(Block_M0 == Block_M1);
    static_assert(Block_N0 == Block_K1 || (Block_N0 / 2) == Block_K1); // Gate Only or Gate+Up

    // pre-shuffle tile size compute (assume only for B matrix)
    // we flatten each wave tile into a 1d linear tensor (at model loading time)
    // e.g. originally we have a Block_N*Block_K tile size; after pre-shuffle
    // we have Block_Nr*Block_Kr*Block_W, where Block_W is Warp_N*Warp_K,
    // and Block_Nr=Block_N/Warp_N, Block_Kr=Block_K/Warp_K
    static constexpr index_t Block_W0  = Warp_N0 * Warp_K0;
    static constexpr index_t Block_Nr0 = Block_N0 / Warp_N0;
    static constexpr index_t Block_Kr0 = Block_K0 / Warp_K0;
    static constexpr index_t Block_W1  = Warp_N1 * Warp_K1;
    static constexpr index_t Block_Nr1 = Block_N1 / Warp_N1;
    static constexpr index_t Block_Kr1 = Block_K1 / Warp_K1;
    static_assert(Block_W0 == Block_W1);
    static_assert(Block_Nr0 == Block_Kr1);
};
} // namespace ck_tile
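The pre-shuffle described in the comment above (done once at model-loading time) is just an index remap: element (n, k) of a Block_N x Block_K B-tile moves to (nr, kr, w), where nr = n / Warp_N, kr = k / Warp_K, and w flattens the Warp_N x Warp_K wave tile. A hedged host-side sketch of that remap (it assumes the wave tile stays row-major inside w, which this commit does not spell out):

#include <cstddef>
#include <vector>

// reorder one [block_n, block_k] row-major tile into [nr, kr, w] with
// w = warp_n * warp_k; illustration of the layout, not library code
std::vector<float> preshuffle_b_tile(const std::vector<float>& b,
                                     int block_n, int block_k,
                                     int warp_n, int warp_k)
{
    const int kr = block_k / warp_k;  // Block_Kr
    const int w  = warp_n * warp_k;   // Block_W
    std::vector<float> out(static_cast<std::size_t>(block_n) * block_k);
    for(int n = 0; n < block_n; n++)
        for(int k = 0; k < block_k; k++)
        {
            int i_nr = n / warp_n, i_n = n % warp_n;
            int i_kr = k / warp_k, i_k = k % warp_k;
            // assumed intra-wave order: row-major (i_n, i_k)
            std::size_t dst = (static_cast<std::size_t>(i_nr) * kr + i_kr) * w +
                              i_n * warp_k + i_k;
            out[dst] = b[static_cast<std::size_t>(n) * block_k + k];
        }
    return out;
}

With this layout one wave's entire Warp_N x Warp_K fragment is a single contiguous run of Block_W elements, which is what lets the kernel view G and D as (nr, kr, w) tensors with stride (kr * Block_W, Block_W, 1).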
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp 0 → 100644 View file @ 49c39b51
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
template <typename BlockShape_>
struct FusedMoeGemmTilePartitioner_Linear
{
    // FusedMoeGemmShape
    using BlockShape = ck_tile::remove_cvref_t<BlockShape_>;

    static constexpr const char* name = "eh"; // expert x hidden

    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*num_sorted_tiles*/,
                                   ck_tile::index_t /*hidden_size*/)
    {
        index_t i_n = blockIdx.x;
        index_t i_m = blockIdx.y;
        return ck_tile::make_tuple(i_m, i_n);
    }
    CK_TILE_HOST static constexpr auto GridSize(index_t max_tokens, index_t hidden_size)
    {
        // TODO: this may need tuning
        index_t ms = ck_tile::integer_divide_ceil(max_tokens, BlockShape::Block_M0);
        index_t ns = ck_tile::integer_divide_ceil(hidden_size, BlockShape::Block_N0);
        return dim3(ns, ms, 1);
    }
};
} // namespace ck_tile
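As a worked example of this linear partitioning (with hypothetical tile sizes, not values from this commit): with Block_M0 = 32 and Block_N0 = 128, a launch covering max_tokens = 28 padded tokens and an intermediate/hidden extent of 512 gets ms = ceil(28/32) = 1 and ns = ceil(512/128) = 4, i.e. dim3(4, 1, 1). blockIdx.x then walks the hidden tiles and blockIdx.y the sorted-token tiles, matching the (i_m, i_n) tuple that operator() returns and that the kernel destructures as (sorted_tile_id, hidden_tile_id).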
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp 0 → 100644 View file @ 49c39b51
This diff is collapsed.
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp 0 → 100644 View file @ 49c39b51
This diff is collapsed.
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp 0 → 100644 View file @ 49c39b51
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// TODO: allow the 2 gemms to have different types
template <typename ADataType_,
          typename GDataType_,
          typename DDataType_,
          typename AccDataType_,
          typename ODataType_,
          typename AScaleDataType_,
          typename GScaleDataType_, // was W0ScaleDataType_, which the usings below never referenced
          typename DScaleDataType_, // was W1ScaleDataType_
          typename YSmoothScaleDataType_,
          typename TopkWeightDataType_,
          typename IndexDataType_, // data type for all indexing
          typename GateActivation_, // = ck_tile::element_wise::Silu,
          typename BlockShape_,     // should be FusedMoeGemmShape
          typename Traits_>
struct FusedMoeGemmPipelineProblem
{
    using ADataType            = remove_cvref_t<ADataType_>;
    using GDataType            = remove_cvref_t<GDataType_>;
    using DDataType            = remove_cvref_t<DDataType_>;
    using AccDataType          = remove_cvref_t<AccDataType_>;
    using ODataType            = remove_cvref_t<ODataType_>;
    using AScaleDataType       = remove_cvref_t<AScaleDataType_>;
    using GScaleDataType       = remove_cvref_t<GScaleDataType_>;
    using DScaleDataType       = remove_cvref_t<DScaleDataType_>;
    using YSmoothScaleDataType = remove_cvref_t<YSmoothScaleDataType_>;
    using TopkWeightDataType   = remove_cvref_t<TopkWeightDataType_>;
    using IndexDataType        = remove_cvref_t<IndexDataType_>;
    // the input of the next gemm should have the same type as A
    using YDataType = ADataType;

    using GateActivation = remove_cvref_t<GateActivation_>;
    using BlockShape     = remove_cvref_t<BlockShape_>;
    using Traits         = remove_cvref_t<Traits_>;
};
} // namespace ck_tile
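As a usage sketch, a problem for an fp16 gate-only pipeline might be instantiated as below. All concrete tile sizes and the trait type here are hypothetical placeholders chosen only to satisfy the FusedMoeGemmShape static_asserts (e.g. Warp_N0 == Warp_K1 and Warp_N0*Warp_K0 == Warp_N1*Warp_K1), not values shipped by this commit:

// hedged sketch: hypothetical shape/trait parameters
using Shape = ck_tile::FusedMoeGemmShape<ck_tile::sequence<32, 256, 128>,  // BlockTile_0
                                         ck_tile::sequence<1, 4, 1>,       // WarpPerBlock_0
                                         ck_tile::sequence<16, 16, 32>,    // WarpTile_0
                                         ck_tile::sequence<32, 128, 256>,  // BlockTile_1
                                         ck_tile::sequence<1, 4, 1>,       // WarpPerBlock_1
                                         ck_tile::sequence<16, 32, 16>>;   // WarpTile_1

using Problem = ck_tile::FusedMoeGemmPipelineProblem<
    ck_tile::fp16_t,             // ADataType
    ck_tile::fp16_t,             // GDataType
    ck_tile::fp16_t,             // DDataType
    float,                       // AccDataType
    ck_tile::fp16_t,             // ODataType
    float,                       // AScaleDataType
    float,                       // GScaleDataType
    float,                       // DScaleDataType
    float,                       // YSmoothScaleDataType
    float,                       // TopkWeightDataType
    ck_tile::index_t,            // IndexDataType
    ck_tile::element_wise::Silu, // GateActivation
    Shape,
    MyMoeTraits>;                // hypothetical traits type (IsGateOnly, padding flags, ...)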