Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
7dc420a8
Commit
7dc420a8
authored
Feb 12, 2025
by
ThomasNing
Browse files
Solve merge conflict and add the gtest for compv4
parents
884a2f7c
ef2b53a9
Changes
68
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
888 additions
and
89 deletions
+888
-89
CMakeLists.txt
CMakeLists.txt
+1
-0
Jenkinsfile
Jenkinsfile
+3
-3
codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp
...de/ck/host/device_batched_gemm_softmax_gemm/operation.hpp
+61
-0
codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp
...lude/ck/host/device_batched_gemm_softmax_gemm/problem.hpp
+47
-0
codegen/include/ck/host/device_gemm_multiple_d/operation.hpp
codegen/include/ck/host/device_gemm_multiple_d/operation.hpp
+2
-0
codegen/include/ck/host/operation/gemm.hpp
codegen/include/ck/host/operation/gemm.hpp
+20
-0
codegen/include/ck/host/types.hpp
codegen/include/ck/host/types.hpp
+15
-0
codegen/src/device_batched_gemm_softmax_gemm.cpp
codegen/src/device_batched_gemm_softmax_gemm.cpp
+38
-0
codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp
...vice_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp
+408
-0
codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp
...gen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp
+67
-35
codegen/src/types.cpp
codegen/src/types.cpp
+20
-0
codegen/test/rtc/include/rtc/hip.hpp
codegen/test/rtc/include/rtc/hip.hpp
+1
-0
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+9
-1
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
+8
-1
example/ck_tile/01_fmha/generate.py
example/ck_tile/01_fmha/generate.py
+2
-1
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+5
-2
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+2
-2
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+40
-38
example/ck_tile/13_moe_sorting/moe_sorting.cpp
example/ck_tile/13_moe_sorting/moe_sorting.cpp
+57
-6
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
+82
-0
No files found.
CMakeLists.txt
View file @
7dc420a8
...
@@ -92,6 +92,7 @@ endif()
...
@@ -92,6 +92,7 @@ endif()
add_compile_options
(
-Wno-bit-int-extension
)
add_compile_options
(
-Wno-bit-int-extension
)
add_compile_options
(
-Wno-pass-failed
)
add_compile_options
(
-Wno-pass-failed
)
add_compile_options
(
-Wno-switch-default
)
add_compile_options
(
-Wno-switch-default
)
add_compile_options
(
-Wno-unique-object-duplication
)
if
(
DL_KERNELS
)
if
(
DL_KERNELS
)
add_definitions
(
-DDL_KERNELS
)
add_definitions
(
-DDL_KERNELS
)
...
...
Jenkinsfile
View file @
7dc420a8
...
@@ -117,7 +117,7 @@ def getDockerImage(Map conf=[:]){
...
@@ -117,7 +117,7 @@ def getDockerImage(Map conf=[:]){
{
{
echo
"Pulling down image: ${image}"
echo
"Pulling down image: ${image}"
retimage
=
docker
.
image
(
"${image}"
)
retimage
=
docker
.
image
(
"${image}"
)
withDockerRegistry
([
credentialsId:
"docker_
test_
cred"
,
url:
""
])
{
withDockerRegistry
([
credentialsId:
"
ck_
docker_cred"
,
url:
""
])
{
retimage
.
pull
()
retimage
.
pull
()
}
}
}
}
...
@@ -148,7 +148,7 @@ def buildDocker(install_prefix){
...
@@ -148,7 +148,7 @@ def buildDocker(install_prefix){
//force building the new docker if that parameter is true
//force building the new docker if that parameter is true
echo
"Building image: ${image_name}"
echo
"Building image: ${image_name}"
retimage
=
docker
.
build
(
"${image_name}"
,
dockerArgs
)
retimage
=
docker
.
build
(
"${image_name}"
,
dockerArgs
)
withDockerRegistry
([
credentialsId:
"docker_
test_
cred"
,
url:
""
])
{
withDockerRegistry
([
credentialsId:
"
ck_
docker_cred"
,
url:
""
])
{
retimage
.
push
()
retimage
.
push
()
}
}
sh
'docker images -q -f dangling=true | xargs --no-run-if-empty docker rmi'
sh
'docker images -q -f dangling=true | xargs --no-run-if-empty docker rmi'
...
@@ -162,7 +162,7 @@ def buildDocker(install_prefix){
...
@@ -162,7 +162,7 @@ def buildDocker(install_prefix){
catch
(
Exception
ex
){
catch
(
Exception
ex
){
echo
"Unable to locate image: ${image_name}. Building image now"
echo
"Unable to locate image: ${image_name}. Building image now"
retimage
=
docker
.
build
(
"${image_name}"
,
dockerArgs
+
' .'
)
retimage
=
docker
.
build
(
"${image_name}"
,
dockerArgs
+
' .'
)
withDockerRegistry
([
credentialsId:
"docker_
test_
cred"
,
url:
""
])
{
withDockerRegistry
([
credentialsId:
"
ck_
docker_cred"
,
url:
""
])
{
retimage
.
push
()
retimage
.
push
()
}
}
}
}
...
...
codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp
0 → 100644
View file @
7dc420a8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
#include "ck/host/operation/gemm.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
namespace
ck
{
namespace
host
{
namespace
device_batched_gemm_softmax_gemm
{
// defines all values need for an instance of fwd conv
struct
Operation_Xdl_CShuffle
{
// returns a vector of instances, only given fusion operators: will use default problem spec
static
std
::
vector
<
std
::
vector
<
Operation_Xdl_CShuffle
>>
CreateOperations
(
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
);
// returns a vector of instances, given a problem spec and fusion operators
static
std
::
vector
<
Operation_Xdl_CShuffle
>
CreateOperations
(
const
Problem
&
prob
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
);
TensorDesc
A
{};
TensorDesc
B
{};
TensorDesc
B1
{};
TensorDesc
C
{};
DataType
acc
=
DataType
::
Float
;
DataType
cs_type
=
DataType
::
Half
;
std
::
string
a_elem_op
=
PassThrough
;
std
::
string
b_elem_op
=
PassThrough
;
std
::
string
b1_elem_op
=
PassThrough
;
std
::
string
c_elem_op
=
PassThrough
;
std
::
string
acc_elem_op
=
Scale
;
std
::
string
prologue
=
""
;
std
::
string
epilogue
=
""
;
std
::
string
gemm_specialization
=
"ck::tensor_operation::device::GemmSpecialization::Default"
;
// tuning parameters
operation
::
TileDescGemmGemm
tile_desc
{};
operation
::
BlockTransferDesc
a_block_transfer
{};
operation
::
BlockTransferDesc
b0_block_transfer
{};
operation
::
BlockTransferDesc
b1_block_transfer
{};
operation
::
CShuffleDesc
cshuffle
{};
operation
::
CBlockTransferDesc
c_block_transfer
{};
bool
mask_out_upper_triangle
=
false
;
// functions to update fusion operators if provided
void
update_prologue
(
const
std
::
string
&
prologue
);
void
update_epilogue
(
const
std
::
string
&
epilogue
);
/**constexpr**/
bool
IsSupported
(
std
::
size_t
MRaw_
,
std
::
size_t
NRaw_
,
std
::
size_t
KRaw_
,
std
::
size_t
Gemm1NRaw_
);
// returns a templated instance
Solution
ToSolution
()
const
;
};
}
// namespace device_batched_gemm_softmax_gemm
}
// namespace host
}
// namespace ck
codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp
0 → 100644
View file @
7dc420a8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
namespace
ck
{
namespace
host
{
namespace
device_batched_gemm_softmax_gemm
{
// defines the problem specification for a GEMM operation
struct
Problem
{
std
::
size_t
M
=
0
;
std
::
size_t
N
=
0
;
std
::
size_t
K
=
0
;
std
::
size_t
O
=
0
;
bool
TransA
=
false
;
bool
TransB
=
false
;
bool
TransB1
=
false
;
bool
TransC
=
false
;
DataType
ADataType
=
DataType
::
Half
;
DataType
BDataType
=
DataType
::
Half
;
DataType
B1DataType
=
DataType
::
Half
;
DataType
CDataType
=
DataType
::
Half
;
std
::
string
AElementOp
=
PassThrough
;
std
::
string
BElementOp
=
PassThrough
;
std
::
string
B1ElementOp
=
PassThrough
;
std
::
string
CElementOp
=
PassThrough
;
std
::
string
AccElementOp
=
Scale
;
// returns the correct device op file for the operation
std
::
string
GetIncludeHeader
()
const
;
// returns a list of instances based on the problem spec and provided fusion operations
std
::
vector
<
Solution
>
GetSolutions
(
const
std
::
string
&
arch
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
)
const
;
};
}
// namespace device_batched_gemm_softmax_gemm
}
// namespace host
}
// namespace ck
codegen/include/ck/host/device_gemm_multiple_d/operation.hpp
View file @
7dc420a8
...
@@ -41,6 +41,8 @@ struct Operation_Xdl_CShuffle
...
@@ -41,6 +41,8 @@ struct Operation_Xdl_CShuffle
operation
::
BlockTransferDesc
b_block_transfer
{};
operation
::
BlockTransferDesc
b_block_transfer
{};
operation
::
CShuffleDesc
cshuffle
{};
operation
::
CShuffleDesc
cshuffle
{};
operation
::
CBlockTransferDesc
c_block_transfer
{};
operation
::
CBlockTransferDesc
c_block_transfer
{};
LoopScheduler
loop_scheduler
{};
PipelineVersion
pipeline_version
{};
// functions to update fusion operators if provided
// functions to update fusion operators if provided
void
update_prologue
(
const
std
::
string
&
prologue
);
void
update_prologue
(
const
std
::
string
&
prologue
);
...
...
codegen/include/ck/host/operation/gemm.hpp
View file @
7dc420a8
...
@@ -23,6 +23,26 @@ struct TileDesc
...
@@ -23,6 +23,26 @@ struct TileDesc
int
n_Xdl_per_wave
=
0
;
int
n_Xdl_per_wave
=
0
;
int
num_gemmk_prefetch_stage
=
0
;
int
num_gemmk_prefetch_stage
=
0
;
};
};
struct
TileDescGemmGemm
{
int
block_size
=
0
;
int
gemm01_m_per_block
=
0
;
int
gemm0_n_per_block
=
0
;
int
gemm0_k_per_block
=
0
;
int
gemm1_n_per_block
=
0
;
int
gemm1_k_per_block
=
0
;
int
ak1
=
0
;
int
bk1
=
0
;
int
b1k1
=
0
;
int
m_per_XDL
=
0
;
int
n_per_XDL
=
0
;
int
gemm0_m_Xdl_per_wave
=
0
;
int
gemm0_n_Xdl_per_wave
=
0
;
int
gemm1_n_Xdl_per_wave
=
0
;
int
num_gemmk_prefetch_stage
=
0
;
};
struct
BlockTransferDesc
struct
BlockTransferDesc
{
{
std
::
string
thread_cluster_length
=
""
;
std
::
string
thread_cluster_length
=
""
;
...
...
codegen/include/ck/host/types.hpp
View file @
7dc420a8
...
@@ -66,6 +66,20 @@ enum class GemmType
...
@@ -66,6 +66,20 @@ enum class GemmType
};
};
std
::
string
ToString
(
GemmType
gt
);
std
::
string
ToString
(
GemmType
gt
);
enum
class
LoopScheduler
{
Default
,
Interwave
,
};
std
::
string
ToString
(
LoopScheduler
ls
);
enum
class
PipelineVersion
{
v1
,
v2
};
std
::
string
ToString
(
PipelineVersion
pv
);
struct
TensorDesc
struct
TensorDesc
{
{
DataType
element
;
DataType
element
;
...
@@ -84,6 +98,7 @@ const std::string S = SequenceStr({xs...});
...
@@ -84,6 +98,7 @@ const std::string S = SequenceStr({xs...});
constexpr
const
char
*
PassThrough
=
"ck::tensor_operation::element_wise::PassThrough"
;
constexpr
const
char
*
PassThrough
=
"ck::tensor_operation::element_wise::PassThrough"
;
constexpr
const
char
*
Bilinear
=
"ck::tensor_operation::element_wise::Bilinear"
;
constexpr
const
char
*
Bilinear
=
"ck::tensor_operation::element_wise::Bilinear"
;
constexpr
const
char
*
Scale
=
"ck::tensor_operation::element_wise::Scale"
;
}
// namespace host
}
// namespace host
}
// namespace ck
}
// namespace ck
codegen/src/device_batched_gemm_softmax_gemm.cpp
0 → 100644
View file @
7dc420a8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
#include "ck/host/utils.hpp"
#include <algorithm>
namespace
ck
{
namespace
host
{
namespace
device_batched_gemm_softmax_gemm
{
// return the relevant device op file based on the operation
std
::
string
Problem
::
GetIncludeHeader
()
const
{
return
"ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp"
;
}
// returns templated instances when provided with a problem specification
std
::
vector
<
Solution
>
Problem
::
GetSolutions
(
const
std
::
string
&
arch
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
)
const
{
if
(
get_xdlop_archs
().
count
(
arch
)
==
0
)
return
{};
auto
ops
=
ck
::
host
::
device_batched_gemm_softmax_gemm
::
Operation_Xdl_CShuffle
::
CreateOperations
(
*
this
,
prologue
,
epilogue
);
// obtains vector of instances
std
::
vector
<
Solution
>
result
;
std
::
transform
(
ops
.
begin
(),
ops
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
const
auto
&
op
)
{
return
op
.
ToSolution
();
// template instance with correct values
});
return
result
;
}
}
// namespace device_batched_gemm_softmax_gemm
}
// namespace host
}
// namespace ck
codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp
0 → 100644
View file @
7dc420a8
This diff is collapsed.
Click to expand it.
codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp
View file @
7dc420a8
...
@@ -62,6 +62,12 @@ void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
...
@@ -62,6 +62,12 @@ void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
// accounts for all possible combinations of Row/Col major
// accounts for all possible combinations of Row/Col major
static
Layout
ToLayout
(
bool
Trans
)
{
return
Trans
?
Layout
::
Column
:
Layout
::
Row
;
}
static
Layout
ToLayout
(
bool
Trans
)
{
return
Trans
?
Layout
::
Column
:
Layout
::
Row
;
}
// clang-format off
// DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1,
// DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
// clang-format on
// Hard-code tuning parameters in modularized fashion, string them together into a vector of
// Hard-code tuning parameters in modularized fashion, string them together into a vector of
// instances
// instances
std
::
vector
<
Operation_Xdl_CShuffle
>
Operation_Xdl_CShuffle
::
CreateOperations
(
std
::
vector
<
Operation_Xdl_CShuffle
>
Operation_Xdl_CShuffle
::
CreateOperations
(
...
@@ -83,6 +89,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
...
@@ -83,6 +89,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
1
},
{
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
1
},
{
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
1
},
{
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
1
},
{
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
1
},
{
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
1
},
// Irregular tile
{
64
,
16
,
16
,
32
,
8
,
8
,
16
,
16
,
1
,
1
,
1
},
// clang-format on
// clang-format on
};
};
...
@@ -100,6 +108,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
...
@@ -100,6 +108,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
{
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
// Irregular tile
{
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
},
// clang-format on
// clang-format on
};
};
...
@@ -109,15 +119,17 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
...
@@ -109,15 +119,17 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// | | | | | | |
// | | | | | | |
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
{
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
{
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
},
{
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
},
// Irregular tile
{
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
},
// clang-format on
// clang-format on
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
{
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
{
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
},
{
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
},
};
};
std
::
vector
<
operation
::
BlockTransferDesc
>
b_block_descriptions_rowmajor
=
{
std
::
vector
<
operation
::
BlockTransferDesc
>
b_block_descriptions_rowmajor
=
{
...
@@ -134,6 +146,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
...
@@ -134,6 +146,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
},
{
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
// Irregular tile
{
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
},
// clang-format on
// clang-format on
};
};
...
@@ -151,6 +165,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
...
@@ -151,6 +165,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
{
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
// Irregular tile
{
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
},
// clang-format on
// clang-format on
};
};
...
@@ -167,6 +183,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
...
@@ -167,6 +183,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
// clang-format on
// clang-format on
};
};
...
@@ -185,6 +202,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
...
@@ -185,6 +202,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{
S
<
1
,
16
,
1
,
8
>
,
8
},
{
S
<
1
,
16
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
// Irregular tile
{
S
<
1
,
16
,
1
,
4
>
,
1
},
// clang-format on
// clang-format on
};
};
...
@@ -199,33 +218,44 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
...
@@ -199,33 +218,44 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
assert
(
tile_descriptions
.
size
()
==
cshuffle_descriptions
.
size
());
assert
(
tile_descriptions
.
size
()
==
cshuffle_descriptions
.
size
());
assert
(
tile_descriptions
.
size
()
==
c_block_descriptions
.
size
());
assert
(
tile_descriptions
.
size
()
==
c_block_descriptions
.
size
());
// Put all values together into a single operation > store into the result vector
const
std
::
vector
<
std
::
tuple
<
LoopScheduler
,
PipelineVersion
>>
scheduler_pipeline_descriptions
=
for
(
std
::
size_t
i
=
0
;
i
<
tile_descriptions
.
size
();
i
++
)
{
{
LoopScheduler
::
Default
,
PipelineVersion
::
v1
},
{
LoopScheduler
::
Interwave
,
PipelineVersion
::
v1
},
{
LoopScheduler
::
Default
,
PipelineVersion
::
v2
},
};
for
(
auto
[
loop_scheduler
,
pipeline_version
]
:
scheduler_pipeline_descriptions
)
{
{
Operation_Xdl_CShuffle
x
;
// Put all values together into a single operation > store into the result vector
x
.
tile_desc
=
tile_descriptions
[
i
];
for
(
std
::
size_t
i
=
0
;
i
<
tile_descriptions
.
size
();
i
++
)
x
.
a_block_transfer
=
a_block_descriptions
[
i
];
{
x
.
b_block_transfer
=
b_block_descriptions
[
i
];
Operation_Xdl_CShuffle
x
;
x
.
cshuffle
=
cshuffle_descriptions
[
i
];
x
.
tile_desc
=
tile_descriptions
[
i
];
x
.
c_block_transfer
=
c_block_descriptions
[
i
];
x
.
a_block_transfer
=
a_block_descriptions
[
i
];
x
.
A
=
TensorDesc
{
prob
.
ADataType
,
ToLayout
(
prob
.
TransA
)};
x
.
b_block_transfer
=
b_block_descriptions
[
i
];
x
.
B
=
TensorDesc
{
prob
.
BDataType
,
ToLayout
(
prob
.
TransB
)};
x
.
cshuffle
=
cshuffle_descriptions
[
i
];
x
.
E
=
TensorDesc
{
prob
.
EDataType
,
ToLayout
(
prob
.
TransE
)};
x
.
c_block_transfer
=
c_block_descriptions
[
i
];
x
.
Ds
=
Transform
(
prob
.
DsTrans
,
prob
.
DsDataType
,
[](
auto
trans
,
auto
dt
)
{
x
.
A
=
TensorDesc
{
prob
.
ADataType
,
ToLayout
(
prob
.
TransA
)};
return
TensorDesc
{
dt
,
ToLayout
(
trans
)};
x
.
B
=
TensorDesc
{
prob
.
BDataType
,
ToLayout
(
prob
.
TransB
)};
});
x
.
E
=
TensorDesc
{
prob
.
EDataType
,
ToLayout
(
prob
.
TransE
)};
x
.
a_elem_op
=
prob
.
AElementOp
;
x
.
Ds
=
Transform
(
prob
.
DsTrans
,
prob
.
DsDataType
,
[](
auto
trans
,
auto
dt
)
{
x
.
b_elem_op
=
prob
.
BElementOp
;
return
TensorDesc
{
dt
,
ToLayout
(
trans
)};
x
.
cde_elem_op
=
prob
.
CDEElementOp
;
});
x
.
gemm_specialization
=
GetGemmSpec
(
prob
.
M
,
x
.
a_elem_op
=
prob
.
AElementOp
;
prob
.
N
,
x
.
b_elem_op
=
prob
.
BElementOp
;
prob
.
K
,
x
.
cde_elem_op
=
prob
.
CDEElementOp
;
x
.
tile_desc
.
m_per_block
,
x
.
gemm_specialization
=
GetGemmSpec
(
prob
.
M
,
x
.
tile_desc
.
n_per_block
,
prob
.
N
,
x
.
tile_desc
.
k_per_block
);
prob
.
K
,
x
.
update_prologue
(
prologue
);
x
.
tile_desc
.
m_per_block
,
x
.
update_epilogue
(
epilogue
);
x
.
tile_desc
.
n_per_block
,
result
.
push_back
(
x
);
x
.
tile_desc
.
k_per_block
);
x
.
loop_scheduler
=
loop_scheduler
;
x
.
pipeline_version
=
pipeline_version
;
x
.
update_prologue
(
prologue
);
x
.
update_epilogue
(
epilogue
);
result
.
push_back
(
x
);
}
}
}
return
result
;
return
result
;
}
}
...
@@ -263,7 +293,7 @@ static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate =
...
@@ -263,7 +293,7 @@ static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate =
"${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, "
"${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, "
"${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, "
"${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, "
"${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, "
"${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, "
"${CDEBlockTransferScalarPerVector_NPerBlock}>"
;
"${CDEBlockTransferScalarPerVector_NPerBlock}
, ${LoopScheduler}, ${PipelineVersion}
>"
;
// use hardcoded instances from vector of operations to substitute values into instance template
// use hardcoded instances from vector of operations to substitute values into instance template
Solution
Operation_Xdl_CShuffle
::
ToSolution
()
const
Solution
Operation_Xdl_CShuffle
::
ToSolution
()
const
...
@@ -336,6 +366,8 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
...
@@ -336,6 +366,8 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
this
->
c_block_transfer
.
cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl
},
this
->
c_block_transfer
.
cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl
},
{
"CDEBlockTransferScalarPerVector_NPerBlock"
,
{
"CDEBlockTransferScalarPerVector_NPerBlock"
,
std
::
to_string
(
this
->
c_block_transfer
.
scalar_per_vector_n_wave_n_per_Xdl
)},
std
::
to_string
(
this
->
c_block_transfer
.
scalar_per_vector_n_wave_n_per_Xdl
)},
{
"LoopScheduler"
,
ToString
(
this
->
loop_scheduler
)},
{
"PipelineVersion"
,
ToString
(
this
->
pipeline_version
)},
};
};
return
Solution
{
InterpolateString
(
DeviceGemmMultipleD_Xdl_CShuffleTemplate
,
values
),
return
Solution
{
InterpolateString
(
DeviceGemmMultipleD_Xdl_CShuffleTemplate
,
values
),
...
...
codegen/src/types.cpp
View file @
7dc420a8
...
@@ -59,6 +59,26 @@ std::string ToString(GemmType gt)
...
@@ -59,6 +59,26 @@ std::string ToString(GemmType gt)
throw
std
::
runtime_error
(
"Incorrect gemm type"
);
throw
std
::
runtime_error
(
"Incorrect gemm type"
);
}
}
std
::
string
ToString
(
LoopScheduler
ls
)
{
switch
(
ls
)
{
case
LoopScheduler
::
Default
:
return
"ck::LoopScheduler::Default"
;
case
LoopScheduler
::
Interwave
:
return
"ck::LoopScheduler::Interwave"
;
}
throw
std
::
runtime_error
(
"Incorrect LoopScheduler type"
);
}
std
::
string
ToString
(
PipelineVersion
pv
)
{
switch
(
pv
)
{
case
PipelineVersion
::
v1
:
return
"ck::PipelineVersion::v1"
;
case
PipelineVersion
::
v2
:
return
"ck::PipelineVersion::v2"
;
}
throw
std
::
runtime_error
(
"Incorrect PipelineVersion type"
);
}
std
::
string
SequenceStr
(
const
std
::
vector
<
int
>&
v
)
std
::
string
SequenceStr
(
const
std
::
vector
<
int
>&
v
)
{
{
return
"ck::Sequence<"
+
return
"ck::Sequence<"
+
...
...
codegen/test/rtc/include/rtc/hip.hpp
View file @
7dc420a8
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
#include <memory>
#include <memory>
#include <stdexcept>
#include <stdexcept>
#include <string>
#include <string>
#include <stdexcept>
namespace
rtc
{
namespace
rtc
{
...
...
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
7dc420a8
...
@@ -506,6 +506,14 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -506,6 +506,14 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
cond
&=
deterministic
==
"f"
cond
&=
deterministic
==
"f"
if
not
cond
:
if
not
cond
:
continue
continue
if
receipt
==
4
:
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
&=
bias
in
[
'no'
,
'bias'
]
cond
&=
dropout
in
[
'no'
,
'dropout_wg32'
,
'dropout_wg16'
]
cond
&=
dpad
==
dvpad
cond
&=
deterministic
==
"f"
if
not
cond
:
continue
api_pool
.
register_dq_dk_dv_traits
(
k
.
api_trait
())
api_pool
.
register_dq_dk_dv_traits
(
k
.
api_trait
())
gen
.
append
(
k
)
gen
.
append
(
k
)
...
@@ -801,4 +809,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im
...
@@ -801,4 +809,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im
_
,
kernels
=
get_bwd_dq_dk_dv_blobs
(
kernel_filter
,
receipt
,
mask_impl
)
_
,
kernels
=
get_bwd_dq_dk_dv_blobs
(
kernel_filter
,
receipt
,
mask_impl
)
for
kernel
in
kernels
:
for
kernel
in
kernels
:
f
.
write
(
str
(
file_path
.
parent
/
GEN_DIR
/
kernel
.
filename
)
+
"
\n
"
)
f
.
write
(
str
(
file_path
.
parent
/
GEN_DIR
/
kernel
.
filename
)
+
"
\n
"
)
f
.
write
(
str
(
file_path
.
parent
/
GEN_DIR
/
FMHA_BWD_API_FILENAME
)
+
"
\n
"
)
f
.
write
(
str
(
file_path
.
parent
/
GEN_DIR
/
FMHA_BWD_API_FILENAME
)
+
"
\n
"
)
\ No newline at end of file
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
View file @
7dc420a8
...
@@ -487,13 +487,20 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
...
@@ -487,13 +487,20 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
if
kernel_filter
!=
None
:
if
kernel_filter
!=
None
:
if
not
fnmatch
.
fnmatch
(
k
.
name
,
kernel_filter
):
if
not
fnmatch
.
fnmatch
(
k
.
name
,
kernel_filter
):
continue
continue
if
receipt
==
2
:
if
receipt
in
(
2
,
3
)
:
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
&=
pipeline
.
F_vlayout
==
'row'
cond
&=
pipeline
.
F_vlayout
==
'row'
cond
&=
pipeline
.
F_bias
in
[
'no'
,
'alibi'
]
cond
&=
pipeline
.
F_bias
in
[
'no'
,
'alibi'
]
cond
&=
pipeline
.
F_squant
==
'f'
cond
&=
pipeline
.
F_squant
==
'f'
if
not
cond
:
if
not
cond
:
continue
continue
if
receipt
==
4
:
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
&=
pipeline
.
F_vlayout
==
'row'
cond
&=
pipeline
.
F_bias
in
[
'no'
,
'bias'
]
cond
&=
pipeline
.
F_squant
==
'f'
if
not
cond
:
continue
api_pool
.
register_traits
(
k
.
api_trait
())
api_pool
.
register_traits
(
k
.
api_trait
())
gen
.
append
(
k
)
gen
.
append
(
k
)
...
...
example/ck_tile/01_fmha/generate.py
View file @
7dc420a8
...
@@ -103,7 +103,8 @@ if __name__ == "__main__":
...
@@ -103,7 +103,8 @@ if __name__ == "__main__":
required
=
False
,
required
=
False
,
help
=
"codegen receipt. 0: generate only 8xhdim coverage
\n
"
+
\
help
=
"codegen receipt. 0: generate only 8xhdim coverage
\n
"
+
\
" 1: generate more instance to cover all hdim
\n
"
+
\
" 1: generate more instance to cover all hdim
\n
"
+
\
" 2: Only generate instance for Flash attention integration"
" 2: Only generate instance for Flash attention integration
\n
"
+
\
" 4: Only generate instance for PyTorch integration"
)
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
...
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
7dc420a8
...
@@ -82,8 +82,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -82,8 +82,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
if
(
s
.
log_level_
>
0
)
if
(
s
.
log_level_
>
0
)
{
{
std
::
cout
<<
"Launching kernel with args:"
std
::
cout
<<
"Launching kernel with args: "
<<
Kernel
::
GetName
()
<<
'\n'
<<
" grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
<<
"shape: "
<<
CodegenGemmShape
::
GetName
()
<<
'\n'
<<
"problem: "
<<
CodegenPipelineProblem
::
GetName
()
<<
'\n'
<<
"pipeline: "
<<
CodegenGemmPipeline
::
GetName
()
<<
'\n'
<<
"grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
<<
", blocks: {"
<<
blocks
.
x
<<
", "
<<
blocks
.
y
<<
", "
<<
blocks
.
z
<<
"}"
<<
", blocks: {"
<<
blocks
.
x
<<
", "
<<
blocks
.
y
<<
", "
<<
blocks
.
z
<<
"}"
<<
std
::
endl
;
<<
std
::
endl
;
}
}
...
...
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
7dc420a8
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#ifndef CK_TILE_PIPELINE_DEFAULT
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V
3
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V
4
#endif
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
...
...
example/ck_tile/03_gemm/run_gemm_example.inc
View file @
7dc420a8
...
@@ -30,8 +30,13 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
...
@@ -30,8 +30,13 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
return
ck_tile
::
make_tuple
(
std
::
max
(
rtol
,
rtol_split_k
),
std
::
max
(
atol
,
atol_split_k
));
return
ck_tile
::
make_tuple
(
std
::
max
(
rtol
,
rtol_split_k
),
std
::
max
(
atol
,
atol_split_k
));
}
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
template
<
typename
ADataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
invoke_gemm
(
ck_tile
::
DeviceMem
&
a_m_k_dev_buf
,
float
invoke_gemm
(
ck_tile
::
DeviceMem
&
a_m_k_dev_buf
,
ck_tile
::
DeviceMem
&
b_k_n_dev_buf
,
ck_tile
::
DeviceMem
&
b_k_n_dev_buf
,
ck_tile
::
DeviceMem
&
c_m_n_dev_buf
,
ck_tile
::
DeviceMem
&
c_m_n_dev_buf
,
...
@@ -57,9 +62,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
...
@@ -57,9 +62,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args
.
stride_B
=
stride_B
;
args
.
stride_B
=
stride_B
;
args
.
stride_C
=
stride_C
;
args
.
stride_C
=
stride_C
;
float
ave_time
=
gemm_calc
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
float
ave_time
=
ALayout
,
BLayout
,
CLayout
>
(
gemm_calc
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
1
,
n_warmup
,
n_repeat
});
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
1
,
n_warmup
,
n_repeat
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_byte
=
std
::
size_t
num_byte
=
...
@@ -69,14 +74,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
...
@@ -69,14 +74,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
std
::
cout
<<
"Run Gemm kernel with M ="
<<
M
<<
" N ="
<<
N
<<
" K ="
<<
K
std
::
cout
<<
"Run Gemm kernel with M ="
<<
M
<<
" N ="
<<
N
<<
" K ="
<<
K
<<
" StrideA ="
<<
stride_A
<<
" StrideB ="
<<
stride_B
<<
" StrideC ="
<<
stride_C
<<
" StrideA ="
<<
stride_A
<<
" StrideB ="
<<
stride_B
<<
" StrideC ="
<<
stride_C
<<
" A_Layout ="
<<
ALayout
::
name
<<
" A_Layout ="
<<
ALayout
::
name
<<
" B_Layout ="
<<
BLayout
::
name
<<
" B_Layout ="
<<
BLayout
::
name
<<
" C_Layout ="
<<
CLayout
::
name
<<
" A Type = "
<<
DataTypeTraits
<
ADataType
>::
name
<<
" C_Layout ="
<<
CLayout
::
name
<<
" B Type = "
<<
DataTypeTraits
<
BDataType
>::
name
<<
" A Type = "
<<
DataTypeTraits
<
ADataType
>::
name
<<
" C Type = "
<<
DataTypeTraits
<
CDataType
>::
name
<<
" : "
<<
ave_time
<<
" ms, "
<<
" B Type = "
<<
DataTypeTraits
<
BDataType
>::
name
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
std
::
endl
;
<<
" C Type = "
<<
DataTypeTraits
<
CDataType
>::
name
<<
" : "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
std
::
endl
;
return
ave_time
;
return
ave_time
;
}
}
...
@@ -92,10 +94,10 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -92,10 +94,10 @@ int run_gemm_example_with_layouts(int argc,
if
(
!
result
)
if
(
!
result
)
return
-
1
;
return
-
1
;
using
ADataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
ADataType
;
using
ADataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
ADataType
;
using
BDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
BDataType
;
using
BDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
BDataType
;
using
CDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
CDataType
;
using
CDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
CDataType
;
using
AccDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
AccDataType
;
using
AccDataType
=
typename
GemmBasicTypeConfig
<
PrecType
>::
AccDataType
;
ck_tile
::
index_t
M
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
M
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
N
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
N
=
arg_parser
.
get_int
(
"n"
);
...
@@ -144,19 +146,19 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -144,19 +146,19 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
invoke_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
invoke_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
ALayout
,
BLayout
,
CLayout
>
(
a_m_k_dev_buf
,
a_m_k_dev_buf
,
b_k_n_dev_buf
,
b_k_n_dev_buf
,
c_m_n_dev_buf
,
c_m_n_dev_buf
,
M
,
M
,
N
,
N
,
K
,
K
,
stride_A
,
stride_A
,
stride_B
,
stride_B
,
stride_C
,
stride_C
,
kbatch
,
kbatch
,
n_warmup
,
n_warmup
,
n_repeat
);
n_repeat
);
c_m_n_dev_buf
.
FromDevice
(
c_m_n_dev_result
.
data
());
c_m_n_dev_buf
.
FromDevice
(
c_m_n_dev_result
.
data
());
bool
pass
=
true
;
bool
pass
=
true
;
...
@@ -171,9 +173,9 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -171,9 +173,9 @@ int run_gemm_example_with_layouts(int argc,
a_m_k
,
b_k_n
,
c_m_n_host_ref
);
a_m_k
,
b_k_n
,
c_m_n_host_ref
);
const
float
max_accumulated_value
=
const
float
max_accumulated_value
=
*
std
::
max_element
(
c_m_n_host_ref
.
mData
.
begin
(),
c_m_n_host_ref
.
mData
.
end
());
*
std
::
max_element
(
c_m_n_host_ref
.
mData
.
begin
(),
c_m_n_host_ref
.
mData
.
end
());
const
auto
rtol_atol
=
calculate_rtol_atol
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
const
auto
rtol_atol
=
calculate_rtol_atol
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
(
K
,
kbatch
,
max_accumulated_value
);
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_host_ref
,
c_m_n_host_ref
,
"Error: Incorrect results!"
,
"Error: Incorrect results!"
,
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
...
@@ -182,7 +184,7 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -182,7 +184,7 @@ int run_gemm_example_with_layouts(int argc,
std
::
cout
<<
"Relative error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{})
std
::
cout
<<
"Relative error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{})
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
std
::
endl
;
<<
std
::
endl
;
std
::
cout
<<
"The CPU veification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
std
::
cout
<<
"The CPU ve
r
ification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
}
else
if
(
arg_parser
.
get_int
(
"v"
)
==
2
)
else
if
(
arg_parser
.
get_int
(
"v"
)
==
2
)
{
{
...
@@ -229,9 +231,9 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -229,9 +231,9 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
const
float
max_accumulated_value
=
const
float
max_accumulated_value
=
*
std
::
max_element
(
c_m_n_gpu_ref
.
mData
.
begin
(),
c_m_n_gpu_ref
.
mData
.
end
());
*
std
::
max_element
(
c_m_n_gpu_ref
.
mData
.
begin
(),
c_m_n_gpu_ref
.
mData
.
end
());
const
auto
rtol_atol
=
calculate_rtol_atol
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
const
auto
rtol_atol
=
calculate_rtol_atol
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
(
K
,
kbatch
,
max_accumulated_value
);
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
,
c_m_n_gpu_ref
,
"Error: Incorrect results!"
,
"Error: Incorrect results!"
,
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
...
@@ -240,7 +242,7 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -240,7 +242,7 @@ int run_gemm_example_with_layouts(int argc,
std
::
cout
<<
"Relative error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{})
std
::
cout
<<
"Relative error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{})
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
std
::
endl
;
<<
std
::
endl
;
std
::
cout
<<
"The GPU veification result is: "
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
std
::
cout
<<
"The GPU ve
r
ification result is: "
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
}
return
pass
;
return
pass
;
...
...
example/ck_tile/13_moe_sorting/moe_sorting.cpp
View file @
7dc420a8
...
@@ -26,6 +26,10 @@ auto create_args(int argc, char* argv[])
...
@@ -26,6 +26,10 @@ auto create_args(int argc, char* argv[])
.
insert
(
"k"
,
"4"
,
"topk"
)
.
insert
(
"k"
,
"4"
,
"topk"
)
.
insert
(
"unit"
,
"32"
,
"unit_size"
)
.
insert
(
"unit"
,
"32"
,
"unit_size"
)
.
insert
(
"moe_buf_size"
,
"0"
,
"moe_buf_size"
)
.
insert
(
"moe_buf_size"
,
"0"
,
"moe_buf_size"
)
.
insert
(
"local_eid"
,
"-1"
,
"a list of experts enabled as local expert. e.g.
\"
0,1,4,5
\"\n
"
"please make sure eid is in ascending order!"
)
.
insert
(
"seed"
,
"-1"
,
"seed to be used, -1 means random every time"
)
.
insert
(
"seed"
,
"-1"
,
"seed to be used, -1 means random every time"
)
.
insert
(
"kname"
,
"0"
,
"when set to 1 it will print kernel name"
)
.
insert
(
"kname"
,
"0"
,
"when set to 1 it will print kernel name"
)
.
insert
(
"warmup"
,
"5"
,
"number of iterations before benchmark the kernel"
)
.
insert
(
"warmup"
,
"5"
,
"number of iterations before benchmark the kernel"
)
...
@@ -74,6 +78,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -74,6 +78,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
int
kname
=
args
.
get_int
(
"kname"
);
int
kname
=
args
.
get_int
(
"kname"
);
int
warmup
=
args
.
get_int
(
"warmup"
);
int
warmup
=
args
.
get_int
(
"warmup"
);
int
repeat
=
args
.
get_int
(
"repeat"
);
int
repeat
=
args
.
get_int
(
"repeat"
);
int
max_output_ids
=
int
max_output_ids
=
ck_tile
::
integer_least_multiple
(
topk
*
tokens
+
num_experts
*
unit_size
-
topk
,
unit_size
);
ck_tile
::
integer_least_multiple
(
topk
*
tokens
+
num_experts
*
unit_size
-
topk
,
unit_size
);
...
@@ -90,6 +95,30 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -90,6 +95,30 @@ bool test_moe_sorting(ck_tile::ArgParser args)
return
false
;
return
false
;
}
}
bool
local_expert_masking
=
args
.
get_str
(
"local_eid"
)
!=
"-1"
;
auto
local_expert_masking_host
=
[
&
]()
{
if
(
local_expert_masking
)
{
auto
local_eid
=
args
.
get_int_vec
(
"local_eid"
);
// std::vector<int> v_ {num_experts, 0};
ck_tile
::
HostTensor
<
IndexType
>
v_
{{
num_experts
}};
v_
.
SetZero
();
for
(
auto
eid
:
local_eid
)
{
if
(
eid
>=
num_experts
)
{
throw
std
::
runtime_error
(
"local_eid larger than number of expert, please check"
);
}
v_
.
mData
[
eid
]
=
1
;
}
return
v_
;
}
else
// return std::vector<int>{};
return
ck_tile
::
HostTensor
<
IndexType
>
{{
1
}};
}();
// tokens already considered batch size
// tokens already considered batch size
ck_tile
::
HostTensor
<
IndexType
>
topk_ids_host
({
tokens
,
topk
},
{
topk
,
1
});
ck_tile
::
HostTensor
<
IndexType
>
topk_ids_host
({
tokens
,
topk
},
{
topk
,
1
});
ck_tile
::
HostTensor
<
WeightType
>
weights_host
({
tokens
,
topk
},
{
topk
,
1
});
ck_tile
::
HostTensor
<
WeightType
>
weights_host
({
tokens
,
topk
},
{
topk
,
1
});
...
@@ -111,6 +140,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -111,6 +140,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
sorted_expert_ids_host
.
get_element_space_size_in_bytes
());
sorted_expert_ids_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
sorted_id_cnt_dev
(
sorted_id_cnt_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
sorted_id_cnt_dev
(
sorted_id_cnt_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
moe_buf_dev
(
moe_buf_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
moe_buf_dev
(
moe_buf_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
local_expert_masking_dev
(
local_expert_masking_host
.
get_element_space_size_in_bytes
());
topk_ids_dev
.
ToDevice
(
topk_ids_host
.
data
());
topk_ids_dev
.
ToDevice
(
topk_ids_host
.
data
());
weights_dev
.
ToDevice
(
weights_host
.
data
());
weights_dev
.
ToDevice
(
weights_host
.
data
());
...
@@ -118,11 +149,15 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -118,11 +149,15 @@ bool test_moe_sorting(ck_tile::ArgParser args)
{
{
moe_buf_dev
.
ToDevice
(
moe_buf_host
.
data
());
moe_buf_dev
.
ToDevice
(
moe_buf_host
.
data
());
}
}
if
(
local_expert_masking
)
local_expert_masking_dev
.
ToDevice
(
local_expert_masking_host
.
data
());
moe_sorting_trait
trait
{
index_prec
,
weight_prec
};
moe_sorting_trait
trait
{
index_prec
,
weight_prec
,
local_expert_masking
};
moe_sorting_args
karg
{
topk_ids_dev
.
GetDeviceBuffer
(),
moe_sorting_args
karg
{
topk_ids_dev
.
GetDeviceBuffer
(),
weights_dev
.
GetDeviceBuffer
(),
weights_dev
.
GetDeviceBuffer
(),
local_expert_masking
?
local_expert_masking_dev
.
GetDeviceBuffer
()
:
nullptr
,
sorted_ids_dev
.
GetDeviceBuffer
(),
sorted_ids_dev
.
GetDeviceBuffer
(),
sorted_weights_dev
.
GetDeviceBuffer
(),
sorted_weights_dev
.
GetDeviceBuffer
(),
sorted_expert_ids_dev
.
GetDeviceBuffer
(),
sorted_expert_ids_dev
.
GetDeviceBuffer
(),
...
@@ -140,15 +175,22 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -140,15 +175,22 @@ bool test_moe_sorting(ck_tile::ArgParser args)
warmup
,
warmup
,
repeat
};
repeat
};
auto
ms
=
moe_sorting
(
trait
,
karg
,
sc
);
auto
ms
=
moe_sorting
(
trait
,
karg
,
sc
);
printf
(
"[%s|%s]tokens:%d, num_experts:%d, topk:%d,
ms:%f ,
"
,
printf
(
"[%s|%s]tokens:%d, num_experts:%d, topk:%d, "
,
index_prec
.
c_str
(),
index_prec
.
c_str
(),
weight_prec
.
c_str
(),
weight_prec
.
c_str
(),
tokens
,
tokens
,
num_experts
,
num_experts
,
topk
,
topk
);
ms
);
if
(
local_expert_masking
)
{
printf
(
"local_eid:%s, "
,
args
.
get_str
(
"local_eid"
).
c_str
());
}
if
(
ms
<
0
)
if
(
ms
<
0
)
printf
(
"not supported
\n
"
);
printf
(
"not supported
\n
"
);
else
printf
(
"ms:%f, "
,
ms
);
fflush
(
stdout
);
fflush
(
stdout
);
if
(
ms
<
0
)
if
(
ms
<
0
)
{
{
...
@@ -174,12 +216,14 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -174,12 +216,14 @@ bool test_moe_sorting(ck_tile::ArgParser args)
int32_t
ref_total_tokens_post_pad
=
0
;
int32_t
ref_total_tokens_post_pad
=
0
;
ck_tile
::
reference_moe_sorting
<
WeightType
,
IndexType
>
(
topk_ids_host
,
ck_tile
::
reference_moe_sorting
<
WeightType
,
IndexType
>
(
topk_ids_host
,
weights_host
,
weights_host
,
local_expert_masking_host
,
sorted_ids_ref
,
sorted_ids_ref
,
sorted_weights_ref
,
sorted_weights_ref
,
sorted_expert_ids_ref
,
sorted_expert_ids_ref
,
ref_total_tokens_post_pad
,
ref_total_tokens_post_pad
,
num_experts
,
num_experts
,
unit_size
);
unit_size
,
local_expert_masking
);
rtn
&=
ck_tile
::
check_err
(
rtn
&=
ck_tile
::
check_err
(
sorted_ids_host
,
sorted_ids_ref
,
std
::
string
(
"OUT Error: Incorrect ids!"
),
1e-6
,
1e-6
);
sorted_ids_host
,
sorted_ids_ref
,
std
::
string
(
"OUT Error: Incorrect ids!"
),
1e-6
,
1e-6
);
rtn
&=
ck_tile
::
check_err
(
sorted_weights_host
,
rtn
&=
ck_tile
::
check_err
(
sorted_weights_host
,
...
@@ -199,9 +243,16 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -199,9 +243,16 @@ bool test_moe_sorting(ck_tile::ArgParser args)
moe_buf_host
,
moe_buf_ref
,
std
::
string
(
"OUT Error: Incorrect zero buf!"
),
0
,
0
);
moe_buf_host
,
moe_buf_ref
,
std
::
string
(
"OUT Error: Incorrect zero buf!"
),
0
,
0
);
}
}
rtn
&=
ref_total_tokens_post_pad
==
sorted_id_cnt_host
.
mData
[
0
];
rtn
&=
ref_total_tokens_post_pad
==
sorted_id_cnt_host
.
mData
[
0
];
printf
(
"total_tokens_post_pad:%d(%d), "
,
ref_total_tokens_post_pad
,
sorted_id_cnt_host
.
mData
[
0
]);
}
}
printf
(
"valid:%s
\n
"
,
rtn
?
"y"
:
"n"
);
printf
(
"valid:%s"
,
rtn
?
"y"
:
"n"
);
fflush
(
stdout
);
if
(
!
rtn
)
printf
(
", (%d)"
,
seed
);
printf
(
"
\n
"
);
fflush
(
stdout
);
fflush
(
stdout
);
return
rtn
;
return
rtn
;
}
}
...
...
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
View file @
7dc420a8
...
@@ -3,6 +3,12 @@
...
@@ -3,6 +3,12 @@
#include "moe_sorting_api.hpp"
#include "moe_sorting_api.hpp"
#ifndef MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_USE_EX_KERNEL 1
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr ck_tile::index_t expert_tile = expert_tile_; \
constexpr ck_tile::index_t expert_tile = expert_tile_; \
...
@@ -17,6 +23,67 @@
...
@@ -17,6 +23,67 @@
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
return ave_time;
#else
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
constexpr bool sub_token_onshot = sub_token_onshot_; \
constexpr bool local_expert_masking = local_expert_masking_; \
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
ms_weight_type, \
sub_token_tile, \
sub_token_onshot, \
local_expert_masking>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
if(row_ % 8 == 0) \
{ \
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 4 == 0) \
{ \
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 2 == 0) \
{ \
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \
} \
else \
{ \
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \
}
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
if(is_sub_token_onshot) \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
}
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
if(is_local_expert_masking) \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, true) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, false) \
}
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH(unroll_num_) \
#define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \
if(a.num_experts <= 8) \
{ \
{ \
...
@@ -38,11 +105,13 @@
...
@@ -38,11 +105,13 @@
{ \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
}
}
#endif
float
moe_sorting
(
moe_sorting_trait
t
,
moe_sorting_args
a
,
ck_tile
::
stream_config
s
)
float
moe_sorting
(
moe_sorting_trait
t
,
moe_sorting_args
a
,
ck_tile
::
stream_config
s
)
{
{
if
(
t
.
weight_type
==
"fp32"
&&
t
.
index_type
==
"int32"
)
if
(
t
.
weight_type
==
"fp32"
&&
t
.
index_type
==
"int32"
)
{
{
#if !MOE_SORTING_USE_EX_KERNEL
if
(
a
.
num_experts
>
127
)
if
(
a
.
num_experts
>
127
)
{
{
printf
(
"lds size exceed, only support experts <127
\n
"
);
printf
(
"lds size exceed, only support experts <127
\n
"
);
...
@@ -83,6 +152,19 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
...
@@ -83,6 +152,19 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
MOE_SORTING_DISPATCH
(
4
);
MOE_SORTING_DISPATCH
(
4
);
}
}
}
}
#else
using
index_t
=
ck_tile
::
index_t
;
using
ms_weight_type
=
float
;
auto
[
r_
,
c_
]
=
ck_tile
::
moe_sorting_get_smem_row_col
(
a
.
tokens
,
a
.
num_experts
);
auto
sub_token_
=
r_
-
2
;
r_
=
(
r_
-
2
)
/
8
;
bool
is_sub_token_onshot
=
a
.
tokens
<=
sub_token_
;
bool
is_local_expert_masking
=
t
.
local_expert_masking
;
(
void
)
c_
;
MOE_SORTING_DISPATCH_EMASK_
(
r_
);
// MOE_SORTING_DISPATCH_ETILE(0, 0);
#endif
}
}
return
-
1
;
return
-
1
;
}
}
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment