Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
6a03c66f
Commit
6a03c66f
authored
Nov 28, 2024
by
letaoqin
Browse files
start gemm down
parent
b2030e34
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
68 additions
and
82 deletions
+68
-82
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
...ile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
+3
-5
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
+36
-9
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
...ed_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
+29
-68
No files found.
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
View file @
6a03c66f
...
@@ -319,13 +319,11 @@ struct FusedMoeGemmGlKernel
...
@@ -319,13 +319,11 @@ struct FusedMoeGemmGlKernel
const
auto
d_window
=
[
&
]()
{
const
auto
d_window
=
[
&
]()
{
const
DDataType
*
d_ptr
=
reinterpret_cast
<
const
DDataType
*>
(
kargs
.
d_ptr
)
+
const
DDataType
*
d_ptr
=
reinterpret_cast
<
const
DDataType
*>
(
kargs
.
d_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_1
+
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_1
;
idx_n0
;
// note interm_idx_nr is along the gemm-k dim of 2nd gemm
// note interm_idx_nr is along the gemm-k dim of 2nd gemm
const
auto
d_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
d_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
d_ptr
,
d_ptr
,
make_tuple
(
kargs
.
hidden_size
,
BlockShape
::
Block_K1
),
make_tuple
(
kargs
.
hidden_size
,
kargs
.
intermediate_size
),
make_tuple
(
kargs
.
intermediate_size
,
1
),
make_tuple
(
kargs
.
intermediate_size
,
1
),
number
<
Pipeline
::
kAlignmentD
>
{},
number
<
Pipeline
::
kAlignmentD
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
@@ -333,7 +331,7 @@ struct FusedMoeGemmGlKernel
...
@@ -333,7 +331,7 @@ struct FusedMoeGemmGlKernel
const
auto
d_window_
=
make_tile_window
(
const
auto
d_window_
=
make_tile_window
(
d_view_
,
d_view_
,
make_tuple
(
number
<
BlockShape
::
Block_N1
>
{},
number
<
BlockShape
::
Block_K1
>
{}),
make_tuple
(
number
<
BlockShape
::
Block_N1
>
{},
number
<
BlockShape
::
Block_K1
>
{}),
{
0
,
0
});
{
0
,
idx_n
0
});
return
d_window_
;
return
d_window_
;
}();
}();
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
View file @
6a03c66f
...
@@ -99,8 +99,6 @@ struct FusedMoeGemmPipeline_General
...
@@ -99,8 +99,6 @@ struct FusedMoeGemmPipeline_General
index_t
intermediate_size
)
index_t
intermediate_size
)
{
{
ignore
=
d_window_
;
ignore
=
d_window_
;
ignore
=
o_window_
;
ignore
=
hidden_size
;
CK_TILE_LDS_ADDR
ADataType
*
smem_0
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
ADataType
*>
(
smem
);
CK_TILE_LDS_ADDR
ADataType
*
smem_0
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
ADataType
*>
(
smem
);
auto
a_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
auto
a_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
...
@@ -137,8 +135,8 @@ struct FusedMoeGemmPipeline_General
...
@@ -137,8 +135,8 @@ struct FusedMoeGemmPipeline_General
clear_tile
(
s_acc
);
// initialize C
clear_tile
(
s_acc
);
// initialize C
constexpr
index_t
kK0
=
BlockShape
::
Block_K0
;
constexpr
index_t
kK0
=
BlockShape
::
Block_K0
;
const
index_t
k0_loops
=
ck_tile
::
integer_divide_ceil
(
intermediate_size
,
kK0
);
const
index_t
k0_loops
=
ck_tile
::
integer_divide_ceil
(
intermediate_size
,
kK0
);
index_t
iCounter
=
k0_loops
-
1
;
index_t
iCounter
0
=
k0_loops
-
1
;
while
(
iCounter
>
0
)
while
(
iCounter
0
>
0
)
{
{
block_sync_lds
();
block_sync_lds
();
...
@@ -152,7 +150,7 @@ struct FusedMoeGemmPipeline_General
...
@@ -152,7 +150,7 @@ struct FusedMoeGemmPipeline_General
g_dram_block
=
load_tile
(
g_global_to_dram_window
);
g_dram_block
=
load_tile
(
g_global_to_dram_window
);
store_tile
(
a_lds_win
,
a_dram_block
);
store_tile
(
a_lds_win
,
a_dram_block
);
iCounter
--
;
iCounter
0
--
;
}
}
// tail
// tail
{
{
...
@@ -162,16 +160,45 @@ struct FusedMoeGemmPipeline_General
...
@@ -162,16 +160,45 @@ struct FusedMoeGemmPipeline_General
// move sacc to LDS
// move sacc to LDS
auto
bridge_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
auto
bridge_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_0
,
Policy
::
template
MakeBridgeLdsBlockDesc
<
Problem
>());
smem_0
,
Policy
::
template
MakeBridgeLdsBlockDesc
<
Problem
>());
auto
bridge_lds_win
=
auto
bridge_
s
lds_win
=
make_tile_window
(
bridge_lds_view
,
make_tile_window
(
bridge_lds_view
,
Policy
::
template
MakeBridgeLdsBlockDesc
<
Problem
>().
get_lengths
(),
Policy
::
template
MakeBridgeLdsBlockDesc
<
Problem
>().
get_lengths
(),
{
0
,
0
});
{
0
,
0
});
auto
y_pre
=
cast_tile
<
YDataType
>
(
s_acc
);
auto
y_pre
=
cast_tile
<
YDataType
>
(
s_acc
);
store_tile
(
bridge_lds_win
,
y_pre
);
store_tile
(
bridge_slds_win
,
y_pre
);
// gemm1 down
block_sync_lds
();
// gemm down
constexpr
auto
gemm_1
=
Policy
::
template
GetBlockGemm1
<
Problem
>();
using
SaccBlockTileType
=
decltype
(
gemm_1
.
MakeCBlockTile
());
auto
o_acc
=
SaccBlockTileType
{};
// y data
auto
bridge_llds_win
=
make_tile_window
(
bridge_lds_view
,
Policy
::
template
MakeBridgeLdsBlockDesc
<
Problem
>().
get_lengths
(),
{
0
,
0
},
Policy
::
template
MakeYTileDistribution
<
Problem
>());
auto
y
=
load_tile
(
bridge_llds_win
);
// d data
auto
d_global_to_dram_window
=
make_tile_window
(
d_window_
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
BlockShape
::
Block_N0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
d_window_
.
get_window_origin
(),
Policy
::
template
MakeGlobalTileDistribution_D
<
Problem
>());
auto
d
=
load_tile
(
d_global_to_dram_window
);
constexpr
index_t
kN1
=
BlockShape
::
Block_N1
;
const
index_t
n1_loops
=
ck_tile
::
integer_divide_ceil
(
hidden_size
,
kN1
);
index_t
iCounter1
=
n1_loops
-
1
;
while
(
iCounter1
>
0
)
{
block_sync_lds
();
ignore
=
bridge_lds_win
;
iCounter1
--
;
}
ignore
=
y
;
ignore
=
d
;
store_tile
(
o_window_
,
a_dram_block
);
store_tile
(
o_window_
,
a_dram_block
);
#if 0
#if 0
//check a matrix gather right or not
//check a matrix gather right or not
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
View file @
6a03c66f
...
@@ -158,25 +158,6 @@ struct FusedMoeGemmPipelineGeneralPolicy
...
@@ -158,25 +158,6 @@ struct FusedMoeGemmPipelineGeneralPolicy
}
}
}
}
template
<
index_t
WarpPerBlock_N_
,
index_t
WarpPerBlock_K_
,
index_t
Repeat_N_
,
index_t
Repeat_K_
,
index_t
WarpSize_
,
index_t
Alignment_
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_Nr_Kr_W
()
{
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
Repeat_N_
,
WarpPerBlock_N_
>
,
sequence
<
Repeat_K_
,
WarpPerBlock_K_
>
,
sequence
<
WarpSize_
,
Alignment_
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
3
>>
,
tuple
<
sequence
<
1
,
1
>
,
sequence
<
0
>>
,
sequence
<
1
,
2
,
3
>
,
sequence
<
0
,
0
,
1
>>
{});
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_A
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_A
()
{
{
...
@@ -231,17 +212,21 @@ struct FusedMoeGemmPipelineGeneralPolicy
...
@@ -231,17 +212,21 @@ struct FusedMoeGemmPipelineGeneralPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_D
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_D
()
{
{
constexpr
auto
PermuteEnum
=
Problem
::
Traits
::
PermuteEnum
;
using
S_
=
remove_cvref_t
<
typename
Problem
::
BlockShape
>
;
using
S_
=
typename
Problem
::
BlockShape
;
using
WarpGemm
=
remove_cvref_t
<
decltype
(
GetWarpGemm1
<
Problem
>
())
>
;
if
constexpr
(
PermuteEnum
==
FusedMoeGemmWeightPermuteEnum
::
b_nr_kr_waveflatten
)
{
constexpr
auto
d_outer_dstr_enc
=
return
MakeGlobalTileDistribution_Nr_Kr_W
<
S_
::
WarpPerBlock_N1
,
tile_distribution_encoding
<
sequence
<
S_
::
WarpPerBlock_N1
>
,
S_
::
WarpPerBlock_K1
,
tuple
<
sequence
<
S_
::
Repeat_N1
>
,
sequence
<
S_
::
Repeat_K1
>>
,
S_
::
Repeat_N1
,
tuple
<
sequence
<
0
>>
,
S_
::
Repeat_K1
,
tuple
<
sequence
<
0
>>
,
get_warp_size
(),
sequence
<
1
,
2
>
,
GetAlignment_D
<
Problem
>
()
>
();
sequence
<
0
,
0
>>
{};
}
constexpr
auto
d_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
d_outer_dstr_enc
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
constexpr
auto
d_block_dstr
=
make_static_tile_distribution
(
d_block_dstr_encode
);
return
d_block_dstr
;
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
@@ -375,50 +360,26 @@ struct FusedMoeGemmPipelineGeneralPolicy
...
@@ -375,50 +360,26 @@ struct FusedMoeGemmPipelineGeneralPolicy
}
}
}
}
// this is used as A matrix for 2nd gemm
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
CBlockTile_Gemm0
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
YTileDistribution
()
{
{
using
S_
=
remove_cvref_t
<
typename
Problem
::
BlockShape
>
;
using
S_
=
remove_cvref_t
<
typename
Problem
::
BlockShape
>
;
using
WarpGemm
=
remove_cvref_t
<
decltype
(
GetWarpGemm0
<
Problem
>
())
>
;
using
WarpGemm
=
remove_cvref_t
<
decltype
(
GetWarpGemm1
<
Problem
>
())
>
;
using
CDataType
=
typename
WarpGemm
::
CDataType
;
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
S_
::
Repeat_M0
,
S_
::
WarpPerBlock_M0
>
,
sequence
<
S_
::
Repeat_N0
,
S_
::
WarpPerBlock_N0
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WarpGemm
::
CWarpDstrEncoding
{});
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
auto
c_block_tensor
=
make_static_distributed_tensor
<
CDataType
>
(
c_block_dstr
);
return
c_block_tensor
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeCBlockTile_Gemm1
()
{
using
S_
=
remove_cvref_t
<
typename
Problem
::
BlockShape
>
;
using
WarpGemm
=
remove_cvref_t
<
decltype
(
GetWarpGemm1
<
Problem
>
())
>
;
using
CDataType
=
typename
WarpGemm
::
CDataType
;
constexpr
auto
c_block_outer_dstr_encoding
=
// TODO: all waves a along different N, but same M
tile
_d
i
str
ibution_encoding
<
sequence
<>
,
constexpr
auto
y_outer
_dstr
_enc
=
tuple
<
sequence
<
S_
::
Repeat_M1
,
S_
::
WarpPerBlock_M1
>
,
tile_distribution_encoding
<
sequence
<
S_
::
WarpPerBlock_M1
>
,
sequence
<
S_
::
Repeat_
N1
,
S_
::
WarpPerBlock_N
1
>>
,
tuple
<
sequence
<
S_
::
Repeat_
M1
>
,
sequence
<
S_
::
Repeat_K
1
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>>
,
tuple
<
sequence
<
1
,
1
>>
,
tuple
<
sequence
<
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
constexpr
auto
y_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WarpGemm
::
CWarpDstrEncoding
{});
y_outer_dstr_enc
,
typename
WarpGemm
::
AWarpDstrEncoding
{});
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
constexpr
auto
y_block_dstr
=
make_static_tile_distribution
(
y_block_dstr_encode
);
auto
c_block_tensor
=
make_static_distributed_tensor
<
CDataType
>
(
c_block_dstr
);
return
y_block_dstr
;
return
c_block_tensor
;
}
}
};
};
}
// namespace ck_tile
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment