Unverified Commit 909f519c authored by Harisankar Sadasivan's avatar Harisankar Sadasivan Committed by GitHub
Browse files

Merge branch 'develop' into universal_streamk

parents 406fa265 3bb0fe6c
...@@ -23,20 +23,7 @@ trigger: ...@@ -23,20 +23,7 @@ trigger:
- Jenkinsfile - Jenkinsfile
- LICENSE - LICENSE
pr: pr: none
autoCancel: true
branches:
include:
- develop
paths:
exclude:
- .github
- docs
- '.*.y*ml'
- '*.md'
- Jenkinsfile
- LICENSE
drafts: false
jobs: jobs:
- template: ${{ variables.CI_COMPONENT_PATH }}/composable_kernel.yml@pipelines_repo - template: ${{ variables.CI_COMPONENT_PATH }}/composable_kernel.yml@pipelines_repo
...@@ -117,7 +117,7 @@ else() ...@@ -117,7 +117,7 @@ else()
add_definitions(-DPROFILER_ONLY) add_definitions(-DPROFILER_ONLY)
set(GPU_TARGETS "" CACHE STRING "" FORCE) set(GPU_TARGETS "" CACHE STRING "" FORCE)
if(GPU_TARGETS) if(GPU_TARGETS)
message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, or gfx11") message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, gfx11 or gfx12")
endif() endif()
if(GPU_ARCH MATCHES "gfx90") if(GPU_ARCH MATCHES "gfx90")
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx908;gfx90a") rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx908;gfx90a")
...@@ -127,8 +127,10 @@ else() ...@@ -127,8 +127,10 @@ else()
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030") rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030")
elseif(GPU_ARCH MATCHES "gfx11") elseif(GPU_ARCH MATCHES "gfx11")
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102") rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102")
elseif(GPU_ARCH MATCHES "gfx12")
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1200;gfx1201")
else() else()
message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, or gfx11") message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, gfx11 or gfx12")
endif() endif()
set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE) set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE)
endif() endif()
......
...@@ -493,6 +493,7 @@ def Build_CK(Map conf=[:]){ ...@@ -493,6 +493,7 @@ def Build_CK(Map conf=[:]){
def variant = env.STAGE_NAME def variant = env.STAGE_NAME
def retimage def retimage
gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
try { try {
(retimage, image) = getDockerImage(conf) (retimage, image) = getDockerImage(conf)
...@@ -660,9 +661,6 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCM ...@@ -660,9 +661,6 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCM
pipeline { pipeline {
agent none agent none
triggers {
parameterizedCron(CRON_SETTINGS)
}
options { options {
parallelsAlwaysFailFast() parallelsAlwaysFailFast()
} }
......
...@@ -66,7 +66,7 @@ else() ...@@ -66,7 +66,7 @@ else()
-Wunreachable-code -Wunreachable-code
-Wunused -Wunused
-Wno-reserved-identifier -Wno-reserved-identifier
-Werror -Werror
-Wno-option-ignored -Wno-option-ignored
-Wsign-compare -Wsign-compare
-Wno-extra-semi-stmt -Wno-extra-semi-stmt
......
...@@ -23,45 +23,45 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa ...@@ -23,45 +23,45 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// clang-format off // clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
< ALayout, < ALayout,
BLayout, BLayout,
CLayout, CLayout,
ADataType, ADataType,
BDataType, BDataType,
CDataType, CDataType,
AccDataType, AccDataType,
CShuffleDataType, CShuffleDataType,
AElementOp, AElementOp,
BElementOp, BElementOp,
CElementOp, CElementOp,
GemmDefault, GemmDefault,
1, // Prefetch stage 1, // Prefetch stage
128, // BlockSize 128, // BlockSize
64, // MPerBlock 64, // MPerBlock
128, // NPerBlock 128, // NPerBlock
64, // KPerBlock 64, // KPerBlock
8, // K1 2, // K1
16, // MPerWmma 16, // MPerWmma
16, // NPerWmma 16, // NPerWmma
2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave 2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave 4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
S<4, 32, 1>, S<4, 32, 1>,
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
2, 2,
8, 2,
8, 2,
true, true,
S<4, 32, 1>, S<4, 32, 1>,
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
2, 2,
8, 2,
8, 2,
true, true,
1, // C shuffle (M Repeat) Per store 1, // C shuffle (M Repeat) Per store
1, // C shuffle (N Repeat) Per store 1, // C shuffle (N Repeat) Per store
S<1, 32, 1, 4>, S<1, 32, 1, 4>,
8>; 8>;
// clang-format on // clang-format on
......
...@@ -159,7 +159,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -159,7 +159,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n); ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break; break;
case 4: case 4:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k); ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n); ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
break; break;
case 5: case 5:
......
...@@ -24,4 +24,4 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -24,4 +24,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32) add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32)
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
\ No newline at end of file
...@@ -83,14 +83,14 @@ using DeviceOpInstanceKKNN = ...@@ -83,14 +83,14 @@ using DeviceOpInstanceKKNN =
2, 2,
4, 4,
4, 4,
true, false,
S<4, 32, 1>, S<4, 32, 1>,
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
2, 2,
4, 4,
4, 4,
true, false,
1, 1,
1, 1,
S<1, 64, 1, 2>, S<1, 64, 1, 2>,
......
...@@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial ...@@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
#define CK_MHA_USE_WAVE_1 #define CK_MHA_USE_WAVE_1
#define CK_MHA_USE_WAVE_2 #define CK_MHA_USE_WAVE_2
#define CK_MHA_USE_WAVE_4 #define CK_MHA_USE_WAVE_4
#define CK_MHA_USE_WAVE_8 //#define CK_MHA_USE_WAVE_8
using DeviceMHAFactory = using DeviceMHAFactory =
std::tuple< std::tuple<
#ifdef CK_MHA_USE_WAVE_1 #ifdef CK_MHA_USE_WAVE_1
...@@ -277,10 +277,10 @@ using DeviceMHAFactory = ...@@ -277,10 +277,10 @@ using DeviceMHAFactory =
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN // CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8, 1, 1, S<1, 64, 1, 2>, 8,
MaskingSpec>, MaskingSpec>
#endif #endif
#ifdef CK_MHA_USE_WAVE_8 #ifdef CK_MHA_USE_WAVE_8
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< ,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
......
...@@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial ...@@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
#define CK_MHA_USE_WAVE_1 #define CK_MHA_USE_WAVE_1
#define CK_MHA_USE_WAVE_2 #define CK_MHA_USE_WAVE_2
#define CK_MHA_USE_WAVE_4 #define CK_MHA_USE_WAVE_4
#define CK_MHA_USE_WAVE_8 //#define CK_MHA_USE_WAVE_8
using DeviceMHAFactory = using DeviceMHAFactory =
std::tuple< std::tuple<
#ifdef CK_MHA_USE_WAVE_1 #ifdef CK_MHA_USE_WAVE_1
...@@ -277,10 +277,10 @@ using DeviceMHAFactory = ...@@ -277,10 +277,10 @@ using DeviceMHAFactory =
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN // CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8, 1, 1, S<1, 64, 1, 2>, 8,
MaskingSpec>, MaskingSpec>
#endif #endif
#ifdef CK_MHA_USE_WAVE_8 #ifdef CK_MHA_USE_WAVE_8
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< ,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
......
...@@ -67,7 +67,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) ...@@ -67,7 +67,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
endforeach() endforeach()
#Do not build any WMMA examples if gfx11 targets are not on the list #Do not build any WMMA examples if gfx11 targets are not on the list
foreach(source IN LISTS FILE_NAME) foreach(source IN LISTS FILE_NAME)
if(NOT EX_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
message("removing wmma example ${source} ") message("removing wmma example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}") list(REMOVE_ITEM FILE_NAME "${source}")
endif() endif()
...@@ -154,7 +154,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) ...@@ -154,7 +154,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
endforeach() endforeach()
#Do not build any WMMA examples if gfx11 targets are not on the list #Do not build any WMMA examples if gfx11 targets are not on the list
foreach(source IN LISTS FILE_NAME) foreach(source IN LISTS FILE_NAME)
if(NOT EX_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
message("removing wmma example ${source} ") message("removing wmma example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}") list(REMOVE_ITEM FILE_NAME "${source}")
endif() endif()
......
# generate a list of kernels, but not actually emit files at config stage # generate a list of kernels, but not actually emit files at config stage
execute_process( execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--direction fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt --api fwd,fwd_splitkv --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
) )
execute_process( execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--direction bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt --api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
) )
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory # NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
...@@ -17,13 +17,13 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS) ...@@ -17,13 +17,13 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
add_custom_command( add_custom_command(
OUTPUT ${FMHA_FWD_GEN_BLOBS} OUTPUT ${FMHA_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--direction fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} --api fwd,fwd_splitkv --output_dir ${CMAKE_CURRENT_BINARY_DIR}
) )
add_custom_command( add_custom_command(
OUTPUT ${FMHA_BWD_GEN_BLOBS} OUTPUT ${FMHA_BWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--direction bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} --api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
) )
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
......
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
GEN_DIR = "" # in Cmake, have to generate files in same folder
\ No newline at end of file
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
DTYPE_MAP = {
"fp16": "ck_tile::fp16_t",
"bf16": "ck_tile::bf16_t",
"fp8" : "ck_tile::fp8_t"
}
MASK_IMPL = {
"generic" : "ck_tile::GenericAttentionMask",
"simplified" : "ck_tile::SimplifiedGenericAttentionMask"
}
_MASK_SIMPLIFIED_MAP = {
"s_no" : "ck_tile::SimplifiedGenericAttentionMask<false>",
"s_mask" : "ck_tile::SimplifiedGenericAttentionMask<true>",
}
_MASK_MAP = {
"no" : "FmhaMasks::NoMask",
"causal" : "FmhaMasks::CausalMask",
"generic" : "FmhaMasks::GenericMask"
}
def get_mask_map(mask : str):
if mask == "generic":
return _MASK_MAP
elif mask == "simplified":
return _MASK_SIMPLIFIED_MAP
else:
assert False
return None
_MASK_CHECK_MAP = {
"no" : "t.mask_type == mask_enum::no_mask",
"causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
"generic" : "t.mask_type == mask_enum::window_generic",
}
_MASK_SIMPLIFIED_CHECK_MAP = {
"s_no" : "t.mask_type == mask_enum::no_mask",
"s_mask" : "t.mask_type != mask_enum::no_mask",
}
def get_mask_check_map(mask : str):
if mask == "generic":
return _MASK_CHECK_MAP
elif mask == "simplified":
return _MASK_SIMPLIFIED_CHECK_MAP
else:
assert False
return None
BIAS_MAP = {
"no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS",
"bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
"alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI"
}
# TODO: this is ugly
BIAS_CHECK_MAP = {
"no" : "bias_enum::no_bias",
"bias" : "bias_enum::elementwise_bias",
"alibi" : "bias_enum::alibi"
}
MODE_MAP = {
"batch" : "false",
"group" : "true"
}
LAYOUT_MAP = {
"row" : "true",
"col" : "false"
}
PIPELINE_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineQRKSVS",
"qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
}
PIPELINE_ENUM_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
}
BOOL_MAP = {
"t" : "true",
"f" : "false"
}
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -114,6 +114,9 @@ auto create_args(int argc, char* argv[]) ...@@ -114,6 +114,9 @@ auto create_args(int argc, char* argv[])
.insert("drop_seed", "1", "seed for random number generator") .insert("drop_seed", "1", "seed for random number generator")
.insert("drop_offset", "0", "offset for random number generator") .insert("drop_offset", "0", "offset for random number generator")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("num_splits",
"1",
"# of splits for key/value. 0 to determine actual number by heuristic")
.insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel"); .insert("repeat", "20", "number of iterations to benchmark the kernel");
...@@ -155,6 +158,106 @@ auto get_elimit<ck_tile::fp8_t>(std::string init_method) ...@@ -155,6 +158,106 @@ auto get_elimit<ck_tile::fp8_t>(std::string init_method)
} }
} }
int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits)
{
// If we have enough to almost fill the SMs, then just use 1 split
if(batch_nhead_mblocks >= 0.8f * num_SMs)
{
return 1;
}
max_splits = std::min({max_splits, num_SMs, num_n_blocks});
float max_efficiency = 0.f;
std::vector<float> efficiency;
efficiency.reserve(max_splits);
auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
// Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
// we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
// (i.e. it's 11 splits anyway).
// So we check if the number of blocks per split is the same as the previous num_splits.
auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
return num_splits == 1 ||
ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
};
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
{
if(!is_split_eligible(num_splits))
{
efficiency.push_back(0.f);
}
else
{
float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs;
float eff = n_waves / ceil(n_waves);
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
if(eff > max_efficiency)
{
max_efficiency = eff;
}
efficiency.push_back(eff);
}
}
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
{
if(!is_split_eligible(num_splits))
{
continue;
}
if(efficiency[num_splits - 1] >= 0.85 * max_efficiency)
{
// printf("num_splits chosen = %d\n", num_splits);
return num_splits;
}
}
return 1;
}
int override_num_splits_if_necessary(
int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
{
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess)
{
return num_splits;
}
hipDeviceProp_t props{};
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess)
{
return num_splits;
}
// tile size should match the generate.py
const int kM0 = 64;
const int kN1 = hdim_v;
const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0);
const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1);
if(num_splits < 1 && p_drop == 0.0f)
{
return num_splits_heuristic(
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);
}
return num_splits;
}
float fmha_fwd_dispatch(fmha_fwd_traits traits,
fmha_fwd_args args,
const ck_tile::stream_config& config)
{
if(1 < args.num_splits)
{
return fmha_fwd_splitkv(traits, args, config);
}
else
{
return fmha_fwd(traits, args, config);
}
}
template <typename DataType> template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser) bool run(const ck_tile::ArgParser& arg_parser)
{ {
...@@ -260,6 +363,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -260,6 +363,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
seed.reset(); seed.reset();
} }
int num_splits = arg_parser.get_int("num_splits");
int stream_warmup = arg_parser.get_int("warmup"); int stream_warmup = arg_parser.get_int("warmup");
int stream_repeat = arg_parser.get_int("repeat"); int stream_repeat = arg_parser.get_int("repeat");
bool kname = arg_parser.get_bool("kname"); bool kname = arg_parser.get_bool("kname");
...@@ -320,6 +425,18 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -320,6 +425,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
} }
// legalize num_splits according to other options
if(num_splits < 1)
{
num_splits = override_num_splits_if_necessary(
batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits);
}
if(128 < num_splits)
{
std::cerr << "num_splits greater than 128 is not supported" << std::endl;
return false;
}
auto get_lengths = [&](bool permute, auto get_lengths = [&](bool permute,
ck_tile::index_t b /*batch*/, ck_tile::index_t b /*batch*/,
ck_tile::index_t h /*nhead*/, ck_tile::index_t h /*nhead*/,
...@@ -361,7 +478,15 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -361,7 +478,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
: std::array<ck_tile::index_t, 2>{batch, nhead}) : std::array<ck_tile::index_t, 2>{batch, nhead})
: std::array<ck_tile::index_t, 2>{1, 1}); : std::array<ck_tile::index_t, 2>{1, 1});
// self define lse data layout as [shape_batch, nhead, shape_seqlen_q] ck_tile::HostTensor<LSEDataType> lse_acc_host(
1 < num_splits ? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::HostTensor<OaccDataType> o_acc_host(
1 < num_splits
? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
// self define lse data layout as [batch, nhead, max_seqlen_q]
ck_tile::HostTensor<LSEDataType> lse_host( ck_tile::HostTensor<LSEDataType> lse_host(
lse ? std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q} lse ? std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}
: std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */); : std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
...@@ -443,6 +568,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -443,6 +568,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
...@@ -479,7 +606,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -479,7 +606,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
: (std::string("(") + std::to_string(seqlen_kpads[0]) + ")")) : (std::string("(") + std::to_string(seqlen_kpads[0]) + ")"))
<< ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias
<< ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant << ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant
<< ", mask:" << mask << ", v:" << vlayout << std::flush; << ", mask:" << mask << ", v:" << vlayout;
if(1 < num_splits)
{
std::cout << ", num_splits:" << num_splits;
}
std::cout << std::flush;
auto fmha_traits = fmha_fwd_traits{hdim_q, auto fmha_traits = fmha_fwd_traits{hdim_q,
hdim_v, hdim_v,
...@@ -523,6 +655,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -523,6 +655,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}(); }();
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
const ck_tile::index_t stride_randval = (max_seqlen_k); const ck_tile::index_t stride_randval = (max_seqlen_k);
const ck_tile::index_t stride_o_acc = hdim_v;
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
// setup nhead_stride_* arguments // setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
...@@ -537,6 +670,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -537,6 +670,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_lse = max_seqlen_q; const ck_tile::index_t nhead_stride_lse = max_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = max_seqlen_q;
const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments // setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
...@@ -545,7 +680,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -545,7 +680,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q); const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q);
const ck_tile::index_t batch_stride_lse_acc = (nhead * max_seqlen_q);
const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
// setup split_stride_* arguments (only used in split-kv kernel)
const ck_tile::index_t split_stride_lse_acc = (batch * nhead * max_seqlen_q);
const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v);
return fmha_fwd_args{q_buf.GetDeviceBuffer(), return fmha_fwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(),
...@@ -553,6 +693,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -553,6 +693,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
: bias_buf.GetDeviceBuffer(), : bias_buf.GetDeviceBuffer(),
randval_buf.GetDeviceBuffer(), randval_buf.GetDeviceBuffer(),
lse_acc_buf.GetDeviceBuffer(),
o_acc_buf.GetDeviceBuffer(),
lse_buf.GetDeviceBuffer(), lse_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(), o_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(),
...@@ -566,6 +708,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -566,6 +708,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
hdim_v, hdim_v,
nhead, nhead,
nhead_k, nhead_k,
num_splits,
scale_s, scale_s,
scale_p, scale_p,
scale_o, scale_o,
...@@ -575,6 +718,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -575,6 +718,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
: stride_bias, : stride_bias,
stride_randval, stride_randval,
stride_o_acc,
stride_o, stride_o,
nhead_stride_q, nhead_stride_q,
nhead_stride_k, nhead_stride_k,
...@@ -582,6 +726,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -582,6 +726,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
nhead_stride_bias, nhead_stride_bias,
nhead_stride_randval, nhead_stride_randval,
nhead_stride_lse, nhead_stride_lse,
nhead_stride_lse_acc,
nhead_stride_o_acc,
nhead_stride_o, nhead_stride_o,
batch_stride_q, batch_stride_q,
batch_stride_k, batch_stride_k,
...@@ -589,7 +735,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -589,7 +735,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
batch_stride_bias, batch_stride_bias,
batch_stride_randval, batch_stride_randval,
batch_stride_lse, batch_stride_lse,
batch_stride_lse_acc,
batch_stride_o_acc,
batch_stride_o, batch_stride_o,
split_stride_lse_acc,
split_stride_o_acc,
mask.left, mask.left,
mask.right, mask.right,
static_cast<ck_tile::index_t>(mask.type), static_cast<ck_tile::index_t>(mask.type),
...@@ -598,7 +748,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -598,7 +748,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
{drop_seed, drop_offset}}; {drop_seed, drop_offset}};
}(); }();
float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config); float ave_time = fmha_fwd_dispatch(fmha_traits, fmha_args, stream_config);
if(ave_time < 0) if(ave_time < 0)
{ {
...@@ -849,14 +999,14 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -849,14 +999,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
lse_host_result.ForEach( lse_host_result.ForEach(
[&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); }); [&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); });
bool lse_pass = ck_tile::check_err(lse_host_result, cur_pass = ck_tile::check_err(lse_host_result,
lse_host_ref, lse_host_ref,
"LSE Error: Incorrect results!", "LSE Error: Incorrect results!",
rtol, rtol,
atol, atol,
/* allow_infinity_ref = */ true); /* allow_infinity_ref = */ true);
pass &= lse_pass; pass &= cur_pass;
if(!cur_pass) if(!cur_pass)
{ {
std::cerr << "LSE mismatch found at batch: " << wb << std::endl std::cerr << "LSE mismatch found at batch: " << wb << std::endl
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment