Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
811b75d3
Commit
811b75d3
authored
Dec 31, 2024
by
shengnxu
Browse files
staging code for backup
parent
1c9a5ff4
Changes
11
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
3227 additions
and
2 deletions
+3227
-2
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
+6
-0
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp
...ile/15_fused_moe/instances/fused_moegemm_api_internal.hpp
+2
-1
include/ck_tile/ops/flatmm/block/flatmm_32x512x256_1x4x1_16x16x64_int8.hpp
...ps/flatmm/block/flatmm_32x512x256_1x4x1_16x16x64_int8.hpp
+757
-0
include/ck_tile/ops/flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp
...flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp
+565
-0
include/ck_tile/ops/flatmm/block/flatmm_uk_config int8.hpp
include/ck_tile/ops/flatmm/block/flatmm_uk_config int8.hpp
+10
-0
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8.inc
...k/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8.inc
+613
-0
include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x256_1x1x1_16x16x32_int8.inc
...lock/uk/flatmm_uk_gfx9_32x512x256_1x1x1_16x16x32_int8.inc
+797
-0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
...s/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
+1
-1
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp
...ed_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp
+405
-0
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
+5
-0
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
...e/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
+66
-0
No files found.
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
View file @
811b75d3
...
@@ -28,6 +28,12 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
...
@@ -28,6 +28,12 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
using
t_
=
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
;
using
t_
=
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
;
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
}
}
else
if
(
t
.
prec_i
==
"int8"
&&
t
.
prec_w
==
"int8"
&&
t
.
prec_o
==
"bf16"
&&
t
.
prec_st
==
"fp32"
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
1
)
{
using
t_
=
fmoe_
<
ck_tile
::
int8_t
,
ck_tile
::
int8_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
256
,
256
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
64
>
,
1
,
1
>
;
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
}
// clang-format on
// clang-format on
return
r
;
return
r
;
}
}
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp
View file @
811b75d3
...
@@ -38,7 +38,8 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
...
@@ -38,7 +38,8 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
f_traits
>
;
f_traits
>
;
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
using
f_pipeline
=
ck_tile
::
FusedMoeGemmPipeline_FlatmmUk
<
f_problem
>
;
using
f_pipeline
=
ck_tile
::
FusedMoeGemmPipeline_FlatmmUk_int8
<
>
;
using
f_partitioner
=
ck_tile
::
FusedMoeGemmTilePartitioner_Linear
<
f_shape
>
;
using
f_partitioner
=
ck_tile
::
FusedMoeGemmTilePartitioner_Linear
<
f_shape
>
;
using
f_kernel
=
ck_tile
::
FusedMoeGemmKernel
<
f_partitioner
,
f_pipeline
,
void
>
;
using
f_kernel
=
ck_tile
::
FusedMoeGemmKernel
<
f_partitioner
,
f_pipeline
,
void
>
;
...
...
include/ck_tile/ops/flatmm/block/flatmm_32x512x256_1x4x1_16x16x64_int8.hpp
0 → 100644
View file @
811b75d3
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp
0 → 100644
View file @
811b75d3
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/flatmm/block/flatmm_uk_config int8.hpp
0 → 100644
View file @
811b75d3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#define CK_TILE_FLATMM_UK_MFMA_FP16 0
#define CK_TILE_FLATMM_UK_MFMA_BF16 1
#define CK_TILE_FLATMM_UK_MFMA_INT8 2
#define CK_TILE_FLATMM_UK_MFMA_FP8 3
#define CK_TILE_FLATMM_UK_MFMA_BF8 4
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8.inc
0 → 100644
View file @
811b75d3
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x256_1x1x1_16x16x32_int8.inc
0 → 100644
View file @
811b75d3
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
View file @
811b75d3
...
@@ -310,7 +310,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -310,7 +310,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
row_coords_o
,
reinterpret_cast
<
const
TopkWeightDataType
*>
(
kargs
.
sorted_weight_ptr
));
row_coords_o
,
reinterpret_cast
<
const
TopkWeightDataType
*>
(
kargs
.
sorted_weight_ptr
));
auto
uk_0
=
Policy
::
template
GetUK_0
<
Problem
>();
auto
uk_0
=
Policy
::
template
GetUK_0
<
Problem
>();
auto
acc_0
=
uk_0
(
a_res
,
auto
acc_0
=
uk_0
(
a_res
,
a_coords
,
a_coords
,
g_res
,
g_res
,
g_coords
,
g_coords
,
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp
0 → 100644
View file @
811b75d3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
namespace
ck_tile
{
/*
This pipeline deal with a gemm(actually 2 gemm) with one very small(token), one very big(weight)
we need to design the pipeline such that all waves along gemm-N dim (gemm-m only 1 wave)
<----- gemm-N ------>
+----+----+----+----+
| w0 | w1 | w2 | w3 | gemm-m
+----+----+----+----+
*/
template
<
typename
Problem_
,
typename
Policy_
=
FusedMoeGemmPipelineFlatmmPolicy
>
struct
FusedMoeGemmPipeline_FlatmmUk_int8
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
BlockShape
=
typename
Problem
::
BlockShape
;
// this is FusedMoeGemmShape
using
ADataType
=
typename
Problem
::
ADataType
;
using
GDataType
=
typename
Problem
::
GDataType
;
using
DDataType
=
typename
Problem
::
DDataType
;
using
AccDataType
=
typename
Problem
::
AccDataType
;
using
ODataType
=
typename
Problem
::
ODataType
;
using
AScaleDataType
=
typename
Problem
::
AScaleDataType
;
using
GScaleDataType
=
typename
Problem
::
GScaleDataType
;
using
DScaleDataType
=
typename
Problem
::
DScaleDataType
;
using
YSmoothScaleDataType
=
typename
Problem
::
YSmoothScaleDataType
;
using
TopkWeightDataType
=
typename
Problem
::
TopkWeightDataType
;
using
IndexDataType
=
typename
Problem
::
IndexDataType
;
using
YDataType
=
typename
Problem
::
YDataType
;
using
Traits
=
typename
Problem
::
Traits
;
static
constexpr
bool
IsGateOnly
=
Traits
::
IsGateOnly
;
static
constexpr
bool
UseSmoothQuant
=
Traits
::
UseSmoothQuant
;
static
constexpr
bool
PadHiddenSize
=
Traits
::
PadHiddenSize
;
static
constexpr
bool
PadIntermediateSize
=
Traits
::
PadIntermediateSize
;
static
constexpr
index_t
kAlignmentA
=
Policy
::
template
GetAlignment_A
<
Problem
>();
static
constexpr
index_t
kAlignmentG
=
Policy
::
template
GetAlignment_G
<
Problem
>();
static
constexpr
index_t
kAlignmentD
=
Policy
::
template
GetAlignment_D
<
Problem
>();
static
constexpr
index_t
kAlignmentO
=
Policy
::
template
GetAlignment_O
<
Problem
>();
static
constexpr
index_t
SLD_A
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
SLD_A
);
static
constexpr
index_t
GLD_A
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
GLD_A
);
static
constexpr
index_t
GLD_B
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
GLD_B
);
static
constexpr
index_t
GST_O
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
GST_O
);
static
constexpr
index_t
kBlockPerCu
=
[]()
{
if
constexpr
(
Problem
::
kBlockPerCu
!=
-
1
)
return
Problem
::
kBlockPerCu
;
else
{
// minimize occupancy
return
2
;
}
}();
static
constexpr
const
char
*
name
=
"flatmm_uk"
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
constexpr
index_t
smem_0
=
Policy
::
template
GetUK_0
<
Problem
>().
GetSmemSize
();
constexpr
index_t
smem_1
=
Policy
::
template
GetUK_1
<
Problem
>().
GetSmemSize
();
constexpr
index_t
smem_bridge
=
BlockShape
::
Block_M0
*
BlockShape
::
Block_N0
*
sizeof
(
YDataType
);
return
max
(
smem_0
,
max
(
smem_1
,
smem_bridge
));
}
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE
static
auto
GetACoord
()
{
constexpr
auto
a_dist
=
Policy
::
template
MakeGlobalTileDistribution_A
<
Problem
>();
const
auto
a_coord
=
a_dist
.
calculate_index
();
return
a_coord
;
}
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE
static
auto
GetOCoord
()
{
constexpr
auto
o_dist
=
Policy
::
template
MakeOGlobalTileDistribution
<
Problem
>();
const
auto
o_coord
=
o_dist
.
calculate_index
();
return
o_coord
;
}
CK_TILE_DEVICE
constexpr
auto
GetNumRowCoords_A
()
{
constexpr
index_t
KLans
=
BlockShape
::
Block_K0
/
kAlignmentA
;
constexpr
index_t
MLans
=
BlockShape
::
BlockSize
/
KLans
;
constexpr
index_t
MRepeat
=
BlockShape
::
Block_M0
/
MLans
;
return
MRepeat
;
}
// TODO: properlly support scatter/gather
CK_TILE_DEVICE
auto
GetRowCoords_A
(
index_t
base_offset
)
{
constexpr
index_t
KLans
=
BlockShape
::
Block_K0
/
kAlignmentA
;
constexpr
index_t
MLans
=
BlockShape
::
BlockSize
/
KLans
;
constexpr
index_t
MRepeat
=
BlockShape
::
Block_M0
/
MLans
;
auto
base_coord
=
threadIdx
.
x
/
KLans
+
base_offset
;
array
<
index_t
,
MRepeat
>
coords
;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
i
)
{
coords
.
at
(
i
)
=
base_coord
+
i
*
MLans
;
});
return
coords
;
}
template
<
typename
ROW_COORDS
>
CK_TILE_DEVICE
auto
GetRowID
(
const
ROW_COORDS
coords
,
const
IndexDataType
*
sorted_token_ids_ptr
)
{
constexpr
index_t
n_size
=
coords
.
size
();
array
<
index_t
,
n_size
>
row_ids
;
static_for
<
0
,
n_size
,
1
>
{}([
&
](
auto
i
)
{
row_ids
.
at
(
i
)
=
sorted_token_ids_ptr
[
coords
[
i
]];
// base_coord + i * MLans;
});
return
row_ids
;
}
template
<
typename
ROW_COORDS
>
CK_TILE_DEVICE
auto
GetWeightScale
(
const
ROW_COORDS
coords
,
const
TopkWeightDataType
*
sorted_weight_ptr
)
{
constexpr
index_t
n_size
=
coords
.
size
();
array
<
TopkWeightDataType
,
n_size
>
w
;
static_for
<
0
,
n_size
,
1
>
{}([
&
](
auto
i
)
{
w
.
at
(
i
)
=
sorted_weight_ptr
[
coords
[
i
]];
// base_coord + i * MLans;
});
return
w
;
}
// TODO: this row id is before shuffle atomic, need use acc distribution
CK_TILE_DEVICE
auto
GetRowCoords_O
(
index_t
base_offset
)
{
constexpr
index_t
MLanes
=
BlockShape
::
Warp_M1
;
constexpr
index_t
Repeat_M
=
BlockShape
::
Repeat_M1
;
auto
base_coord
=
threadIdx
.
x
%
MLanes
+
base_offset
;
array
<
index_t
,
Repeat_M
>
coords
;
static_for
<
0
,
Repeat_M
,
1
>
{}([
&
](
auto
i
)
{
coords
.
at
(
i
)
=
base_coord
+
i
*
MLanes
;
});
return
coords
;
}
template
<
typename
Karg
>
CK_TILE_DEVICE
auto
operator
()(
const
Karg
&
kargs
,
CK_TILE_LDS_ADDR
void
*
smem
,
index_t
sorted_tile_id
,
index_t
intermediate_tile_id
)
{
constexpr
index_t
hidden_radio_0
=
IsGateOnly
?
1
:
2
;
ck_tile
::
index_t
shared_intermediate_size_0
=
kargs
.
intermediate_size
;
ck_tile
::
index_t
shared_intermediate_size_1
=
kargs
.
intermediate_size
/
hidden_radio_0
;
index_t
nr_0
=
shared_intermediate_size_0
/
BlockShape
::
Warp_N0
;
// divide N in W
index_t
kr_0
=
kargs
.
hidden_size
/
BlockShape
::
Warp_K0
;
// divide K in W
index_t
nr_1
=
kargs
.
hidden_size
/
BlockShape
::
Warp_N1
;
index_t
kr_1
=
shared_intermediate_size_1
/
BlockShape
::
Warp_K1
;
const
IndexDataType
expert_id
=
__builtin_amdgcn_readfirstlane
(
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
sorted_expert_ids_ptr
)[
sorted_tile_id
]);
index_t
expert_stride_0
=
shared_intermediate_size_0
*
kargs
.
hidden_size
;
index_t
expert_stride_1
=
shared_intermediate_size_1
*
kargs
.
hidden_size
;
/////////////
index_t
a_scale_expert_stride_0
=
kargs
.
hidden_size
;
index_t
g_scale_expert_stride_0
=
shared_intermediate_size_0
;
index_t
d_scale_expert_stride_1
=
kargs
.
hidden_size
;
// nr*kr*w
index_t
interm_idx_nr0
=
__builtin_amdgcn_readfirstlane
(
intermediate_tile_id
*
BlockShape
::
Block_Nr0
);
// intermediate_tile_id * Block_N / (N in W)
index_t
interm_idx_kr1
=
__builtin_amdgcn_readfirstlane
(
intermediate_tile_id
*
BlockShape
::
Block_Kr1
);
// intermediate_tile_id * Block_N / (N in W)
auto
row_coords_a
=
GetRowCoords_A
(
sorted_tile_id
*
BlockShape
::
Block_M0
);
auto
row_ids_a
=
GetRowID
(
row_coords_a
,
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
sorted_token_ids_ptr
));
auto
token_id
=
row_ids_a
&
0xffffff
;
//addr in fact
auto
a_coords
=
generate_tuple
(
[
&
](
auto
i
)
{
return
(
token_id
)
*
kargs
.
stride_token
+
threadIdx
.
x
%
(
BlockShape
::
Block_K0
/
kAlignmentA
)
*
kAlignmentA
;
},
number
<
row_ids_a
.
size
()
>
{});
auto
a_res
=
make_wave_buffer_resource
(
reinterpret_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
),
kargs
.
num_tokens
*
kargs
.
stride_token
*
sizeof
(
ADataType
));
//////aq
auto
aq_win
=
[
&
]()
{
const
AScaleDataType
*
aq_ptr
=
reinterpret_cast
<
const
AScaleDataType
*>
(
kargs
.
a_scale_ptr
);
auto
aq_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
aq_ptr
,
make_tuple
(
kargs
.
num_tokens
*
kargs
.
topk
),
number
<
1
>
{});
return
aq_view_
;
}();
auto
aq_res
=
aq_win
.
get_buffer_view
().
cached_buf_res_
;
////////
auto
g_win
=
[
&
]()
{
const
GDataType
*
g_ptr
=
reinterpret_cast
<
const
GDataType
*>
(
kargs
.
g_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_0
+
interm_idx_nr0
*
kr_0
*
BlockShape
::
Block_W0
;
auto
g_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
g_ptr
,
make_tuple
(
nr_0
,
kr_0
,
number
<
BlockShape
::
Block_W0
>
{}),
make_tuple
(
kr_0
*
BlockShape
::
Block_W0
,
number
<
BlockShape
::
Block_W0
>
{},
1
),
number
<
kAlignmentG
>
{},
number
<
1
>
{});
auto
g_window_
=
make_tile_window_linear_raw
(
g_view_
,
make_tuple
(
number
<
BlockShape
::
Block_Nr0
>
{},
number
<
BlockShape
::
Block_Kr0
>
{},
number
<
BlockShape
::
Block_W0
>
{}),
{
0
,
0
,
0
},
Policy
::
template
MakeGlobalTileDistribution_G
<
Problem
>(),
sequence
<
0
,
1
,
1
>
{});
return
g_window_
;
}();
auto
g_res
=
g_win
.
get_bottom_tensor_view
().
get_buffer_view
().
cached_buf_res_
;
auto
g_coords
=
generate_tuple
([
&
](
auto
i
)
{
return
g_win
.
cached_coords_
[
i
].
get_offset
();
},
number
<
decltype
(
g_win
)
::
NumAccess_NonLinear
>
{});
//////gq
auto
gq_win
=
[
&
]()
{
const
GDataType
*
g_ptr
=
reinterpret_cast
<
const
GScaleDataType
*>
(
kargs
.
g_scale_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
g_scale_expert_stride_0
+
intermediate_tile_id
*
BlockShape
::
Block_N0
;
// const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr);//remember to add expert id for inline
auto
g_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
g_ptr
,
make_tuple
(
shared_intermediate_size_1
),
number
<
1
>
{});
return
g_view_
;
}();
auto
gq_res
=
gq_win
.
get_buffer_view
().
cached_buf_res_
;
////
const
auto
d_win
=
[
&
]()
{
const
DDataType
*
d_ptr
=
reinterpret_cast
<
const
DDataType
*>
(
kargs
.
d_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_1
+
interm_idx_kr1
*
BlockShape
::
Block_W1
;
// note interm_idx_nr0 is along the gemm-k dim of 2nd gemm
const
auto
d_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
d_ptr
,
make_tuple
(
nr_1
,
kr_1
,
BlockShape
::
Block_W1
),
make_tuple
(
kr_1
*
BlockShape
::
Block_W1
,
BlockShape
::
Block_W1
,
1
),
number
<
kAlignmentD
>
{},
number
<
1
>
{});
const
auto
d_window_
=
make_tile_window_linear_raw
(
d_view_
,
make_tuple
(
number
<
BlockShape
::
Block_Nr1
>
{},
number
<
BlockShape
::
Block_Kr1
>
{},
number
<
BlockShape
::
Block_W1
>
{}),
{
0
,
0
,
0
},
Policy
::
template
MakeGlobalTileDistribution_D
<
Problem
>(),
sequence
<
0
,
1
,
1
>
{});
return
d_window_
;
}();
auto
d_res
=
d_win
.
get_bottom_tensor_view
().
get_buffer_view
().
cached_buf_res_
;
//////gq
auto
dq_win
=
[
&
]()
{
// const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.d_scale_ptr) + static_cast<long_index_t>(expert_id) * d_scale_expert_stride_0;
const
GDataType
*
g_ptr
=
reinterpret_cast
<
const
GScaleDataType
*>
(
kargs
.
d_scale_ptr
)
//remember to add expert_id as expert_idx
auto
g_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
g_ptr
,
make_tuple
(
kargs
.
hidden_size
),
number
<
1
>
{});
return
g_view_
;
}();
auto
dq_res
=
dq_win
.
get_buffer_view
().
cached_buf_res_
;
////
// TODO: load D order is N0.K0...127, N64.K0...127, N0.K128...255, N64.K128...255
// block-k=512, block-n=128
// wg |<----- W_ ----->|
// Nr(2)*Nw(4)* Kr *Kr0(4)*Kr1(4) * [Kl(4)*Nl(16)*Kv(8)]->one issue
// y p y y p p y
// 1 2 0(imm)
auto
d_coords
=
[
&
]()
{
constexpr
index_t
Nr_
=
4
;
constexpr
index_t
Nw_
=
4
;
constexpr
index_t
Kr0_
=
2
;
//no more need in int8, method changed, this will be handed in res_s
constexpr
index_t
Kr1_
=
4
;
constexpr
index_t
Kl_
=
4
;
constexpr
index_t
Nl_
=
16
;
constexpr
index_t
Kv_
=
16
;
constexpr
index_t
W_
=
Kl_
*
Nl_
*
Kv_
;
//constexpr index_t num_offsets_ = Nr_ * Kr0_;
constexpr
index_t
num_offsets_
=
Nr_
;
index_t
base_os_
=
(
threadIdx
.
x
%
64
)
*
Kv_
+
(
threadIdx
.
x
/
64
)
*
shared_intermediate_size_1
*
Nl_
;
// Kr0_ * Kr1_ * W_;
return
generate_tuple
(
[
&
](
auto
i
)
{
constexpr
auto
i_nr_
=
number
<
i
%
Nr_
>
{};
return
i_nr_
*
shared_intermediate_size_1
*
Nw_
*
Nl_
+
base_os_
;
},
number
<
num_offsets_
>
{});
}();
auto
o_coords
=
generate_tuple
(
[
&
](
auto
i
)
{
return
token_id
*
kargs
.
stride_token
+
threadIdx
.
x
%
(
BlockShape
::
Block_N1
/
kAlignmentO
)
*
kAlignmentO
;
},
number
<
row_ids_a
.
size
()
>
{});
auto
o_flags
=
generate_tuple
([
&
](
auto
i
)
{
return
cmp_lt_to_exec
(
token_id
,
kargs
.
num_tokens
);
},
number
<
row_ids_a
.
size
()
>
{});
auto
bridge_sst_win
=
[
&
]()
{
constexpr
auto
desc_
=
Policy
::
template
MakeBridgeLdsStoreForUKDesc
<
Problem
>();
constexpr
auto
dist_
=
Policy
::
template
GetUK_0
<
Problem
>().
MakeCBlockDist
();
return
make_tile_window_linear
(
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
YDataType
*>
(
smem
),
desc_
),
desc_
.
get_lengths
(),
{
0
,
0
},
dist_
);
}();
auto
o_res
=
make_wave_buffer_resource
(
reinterpret_cast
<
const
ODataType
*>
(
kargs
.
o_ptr
),
kargs
.
num_tokens
*
kargs
.
stride_token
*
sizeof
(
ODataType
));
auto
row_coords_o
=
GetRowCoords_O
(
sorted_tile_id
*
BlockShape
::
Block_M0
);
auto
w_scale
=
GetWeightScale
(
row_coords_o
,
reinterpret_cast
<
const
TopkWeightDataType
*>
(
kargs
.
sorted_weight_ptr
));
auto
uk_0
=
Policy
::
template
GetUK_0
<
Problem
>();
auto
acc_0
=
uk_0
(
row_ids_a
,
//fake token id, 2D index for X scale
aq_res
,
gq_res
,
dq_res
,
a_res
,
a_coords
,
g_res
,
g_coords
,
smem
,
kargs
.
hidden_size
,
BlockShape
::
Block_K0
,
// tile offset for B matrix each unroll
BlockShape
::
Block_Kr0
*
BlockShape
::
Block_W0
);
// tile offset for B matrix each unroll
// sweep_tile(
// acc_0,
// [&](auto idx0, auto idx1) {
// fp32x2_t v_{acc_0(idx0), acc_0(idx1)};
// typename Problem::GateActivation{}(v_, v_);
// acc_0(idx0) = v_.x;
// acc_0(idx1) = v_.y;
// },
// sequence<1, 2>{});
// auto y_pre = cast_tile<YDataType>(acc_0);
// block_sync_lds();
// store_tile(bridge_sst_win, y_pre);
// block_sync_lds();
auto
uk_1
=
Policy
::
template
GetUK_1
<
Problem
>();
uk_1
(
d_res
,
d_coords
,
o_res
,
o_coords
,
o_flags
,
smem
,
kargs
.
hidden_size
,
// total n number
w_scale
,
shared_intermediate_size_1
*
Block_N1
-
kr_1
*
BlockShape
::
Block_W1
,
// along N
kr_1
*
BlockShape
::
Block_W1
,
BlockShape
::
Block_N1
);
// along N
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
View file @
811b75d3
...
@@ -140,5 +140,10 @@ using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution =
...
@@ -140,5 +140,10 @@ using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
fp8_t
,
fp8_t
,
WGAttrCtlEnum
::
Default_
>
,
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
fp8_t
,
fp8_t
,
WGAttrCtlEnum
::
Default_
>
,
2
,
2
,
swizzle_factor
>>
;
swizzle_factor
>>
;
// int8
using
WarpGemmMfma_i32_16x16x64_int8_int8_CTransposed
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_i32_16x16x32_i8
<
WGAttrCtlEnum
::
Default_
>
,
2
>>
;
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
View file @
811b75d3
...
@@ -618,6 +618,72 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
...
@@ -618,6 +618,72 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
return
c_vec
;
return
c_vec
;
}
}
};
};
template
<
WGAttrCtlEnum
Ctrl_
=
WGAttrCtlEnum
::
Default_
>
struct
WarpGemmAttributeMfmaImpl_i32_16x16x32_i8
{
static
constexpr
WGAttrCtlEnum
Ctrl
=
Ctrl_
;
using
ADataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
CDataType
=
int32_t
;
using
AVecType
=
ext_vector_t
<
ADataType
,
8
>
;
using
BVecType
=
ext_vector_t
<
BDataType
,
8
>
;
using
CVecType
=
ext_vector_t
<
CDataType
,
4
>
;
static
constexpr
index_t
kM
=
16
;
static
constexpr
index_t
kN
=
16
;
static
constexpr
index_t
kK
=
32
;
static
constexpr
index_t
kAMLane
=
16
;
static
constexpr
index_t
kBNLane
=
16
;
static
constexpr
index_t
kABKLane
=
4
;
static
constexpr
index_t
kABKPerLane
=
8
;
static
constexpr
index_t
kCMLane
=
4
;
static
constexpr
index_t
kCNLane
=
16
;
static
constexpr
index_t
kCM0PerLane
=
1
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
template
<
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
,
bool_constant
<
post_nop_
>
=
{})
const
{
DISPATCH_MFMA_CTRL_
(
"v_mfma_i32_16x16x32_i8"
,
Ctrl
)
else
{
#if defined(__gfx94__)
c_vec
=
__builtin_amdgcn_mfma_i32_16x16x32i8
(
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
c_vec
,
0
,
0
,
0
);
#elif defined(__gfx908__) || defined(__gfx90a__)
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
float
a_f32
=
type_convert
<
float
>
(
reinterpret_cast
<
const
thread_buffer
<
ADataType
,
8
>&>
(
a_vec
)
.
template
get_as
<
ADataType
>()[
number
<
k
>
{}]);
float
b_f32
=
type_convert
<
float
>
(
reinterpret_cast
<
const
thread_buffer
<
BDataType
,
8
>&>
(
b_vec
)
.
template
get_as
<
BDataType
>()[
number
<
k
>
{}]);
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x2f32
(
a_f32
,
b_f32
,
c_vec
,
0
,
0
,
0
);
});
#else
ck_tile
::
ignore
=
c_vec
;
ck_tile
::
ignore
=
a_vec
;
ck_tile
::
ignore
=
b_vec
;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
CVecType
c_vec
{
0
};
operator
()(
c_vec
,
a_vec
,
b_vec
);
return
c_vec
;
}
};
#undef DISPATCH_MFMA_
#undef DISPATCH_MFMA_
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment