Unverified Commit 58d75b7a authored by M.Emin Ozturk's avatar M.Emin Ozturk Committed by GitHub
Browse files

Merge branch 'develop' into gemm_bf16_sk_muozturk

parents 7ed95722 627a27bd
# reference
this folder contains reference implementation of a specific op. Note by including a specific header, you are including the implementation(expecially the gpu implementation) into your source code, and compile that kernel into the fatbin, hence may increase your kernel obj code length. Usually the header starts with `reference_` is a cpu reference implementation. The header starts with `naive_` contains a gpu implementation with a small launcher.
TODO: move `host/reference` under this folder
This diff is collapsed.
......@@ -7,6 +7,7 @@ import copy
NS = 'ck_tile'
OPS = 'ops'
REF = 'ref'
OPS_COMMON = 'common' # common header will be duplicated into ops/* other module
HEADER_COMMON = f"""// SPDX-License-Identifier: MIT
......@@ -29,6 +30,9 @@ class submodule_t:
def push(self, f):
if len(f.parents) != 1: # ignore ./xxx.hpp
mod = get_module(f)
# ref is supposed to include one header on demand
if mod == REF:
return
if mod == OPS:
if mod not in self.m.keys():
self.m[mod] = dict()
......
......@@ -52,6 +52,9 @@ using device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances =
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 160, 64, 8, 8, 16, 16, 8, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 32, 32, 1, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 128, 64, 8, 8, 32, 32, 5, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
......
......@@ -42,6 +42,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std
//##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#ifdef __gfx94__
// Compute friendly
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
......@@ -72,6 +73,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std:
//##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
......
......@@ -6,7 +6,7 @@ set(CK_TILE_SRC_FOLDER ${CMAKE_SOURCE_DIR}/include/ck_tile/)
# CK Codegen requires dataclass which is added in Python 3.7
# Python version 3.8 is required for general good practice as it is default for Ubuntu 20.04
if(NOT CK_USE_ALTERNATIVE_PYTHON)
find_package(PythonInterp 3 REQUIRED)
find_package(Python3 COMPONENTS Interpreter Development)
else()
message("Using alternative python version")
set(EXTRA_PYTHON_PATH)
......@@ -33,7 +33,7 @@ set(FMHA_KNOWN_APIS "fwd,fwd_splitkv,fwd_appendkv,bwd")
# Note: The receipt 3 arg filters the generated backwards instances to reduce compilation time.
# With receipt 3 set, we are generating instances for datatype == {fp16 || bfp16}, bias == {no || alibi}, deterministic == off, and dpad == dvpad.
execute_process(
COMMAND ${PYTHON_EXECUTABLE} ${FMHA_SRC_FOLDER}/generate.py
COMMAND ${Python3_EXECUTABLE} ${FMHA_SRC_FOLDER}/generate.py
--list_blobs ${FMHA_CPP_FOLDER}/blob_list.txt
--api ${FMHA_KNOWN_APIS}
--receipt 3
......@@ -50,7 +50,7 @@ endif()
# With receipt 3 set, we are generating instances for datatype == {fp16 || bfp16}, bias == {no || alibi}, deterministic == off, and dpad == dvpad.
add_custom_command(
OUTPUT ${FMHA_GEN_BLOBS}
COMMAND ${PYTHON_EXECUTABLE} ${FMHA_SRC_FOLDER}/generate.py
COMMAND ${Python3_EXECUTABLE} ${FMHA_SRC_FOLDER}/generate.py
--output_dir ${FMHA_CPP_FOLDER}
--api ${FMHA_KNOWN_APIS}
--receipt 3
......
example/01_gemm/gemm_xdl_fp8_streamk_v3.cpp
example/01_gemm/run_gemm_example_streamk_v2.inc
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_comp_mnpadding_instance.cpp
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp
profiler/src/profile_gemm_universal_streamk.cpp
modified_files.txt
......@@ -48,6 +48,7 @@ bool profile_gemm_universal_batched_impl(int do_verification,
int StrideB,
int StrideC,
int BatchCount,
int KBatch,
int n_warmup,
int n_iter,
uint64_t rotating = 0)
......@@ -147,89 +148,100 @@ bool profile_gemm_universal_batched_impl(int do_verification,
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
float best_kbatch = 0;
// profile device op instances
for(auto& op_ptr : op_ptrs)
{
std::unique_ptr<tensor_operation::device::BaseArgument> argument_ptr;
// false branch for multi d dl kernel
argument_ptr =
op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
{},
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
BatchCount,
StrideA,
StrideB,
{},
StrideC,
BatchStrideA,
BatchStrideB,
{},
BatchStrideC,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init C to zero before profiling next kernel
c_device_buf.SetZero();
std::string op_name = op_ptr->GetTypeString();
std::vector<int> kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38};
float ave_time = invoker_ptr->Run(
argument_ptr.get(),
StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter, true, rotating_count});
if(KBatch > 0)
{
kbatch_list = {KBatch};
}
std::size_t flop = std::size_t(2) * BatchCount * M * N * K;
for(std::size_t i = 0; i < kbatch_list.size(); i++)
{
auto kbatch_curr = kbatch_list[i];
auto argument_ptr =
op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
{},
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
BatchCount,
StrideA,
StrideB,
{},
StrideC,
BatchStrideA,
BatchStrideB,
{},
BatchStrideC,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
kbatch_curr);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
std::string op_name = op_ptr->GetTypeString();
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N) *
BatchCount;
float ave_time = invoker_ptr->Run(
argument_ptr.get(),
StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter, true, rotating_count});
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
std::size_t flop = std::size_t(2) * BatchCount * M * N * K;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N) *
BatchCount;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << op_name << std::endl;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
if(tflops > best_tflops)
{
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
float gb_per_sec = num_btype / 1.E6 / ave_time;
if(do_verification)
{
c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << op_name << ", KBatch " << kbatch_curr << std::endl;
pass = pass & ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result);
if(tflops > best_tflops)
{
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr;
}
if(do_log)
if(do_verification)
{
LogRangeAsType<float>(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_host: ", c_g_m_n_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_device: ", c_g_m_n_device_result.mData, ",")
<< std::endl;
c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
pass = pass & ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result);
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(
std::cout << "c_host: ", c_g_m_n_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_device: ", c_g_m_n_device_result.mData, ",")
<< std::endl;
}
}
}
}
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem"
<< std::endl;
}
}
}
......@@ -270,8 +282,8 @@ bool profile_gemm_universal_batched_impl(int do_verification,
std::cout << " B = " << BatchCount << " M = " << M << " N = " << N << " K = " << K
<< " StrideA = " << StrideA << " StrideB = " << StrideB << " StrideC = " << StrideC
<< ": " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
<< " GB/s, " << best_op_name << std::endl;
<< " KBatch = " << best_kbatch << ": " << best_ave_time << " ms, " << best_tflops
<< " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass;
}
......
......@@ -82,7 +82,7 @@ def parse_logfile(logfile):
StrideA=[]
StrideB=[]
StrideC=[]
if 'perf_gemm.log' in logfile:
if 'perf_gemm' in logfile and 'gemm_bilinear' not in logfile:
for line in open(logfile):
if 'Best Perf' in line:
lst=line.split()
......@@ -260,7 +260,7 @@ def main():
conn = sqlEngine.connect()
#save gemm performance tests:
if 'perf_gemm.log' in filename:
if 'perf_gemm' in filename and 'gemm_bilinear' not in filename:
#write the ck_gemm_test_params table only needed once the test set changes
#post_test_params(test_list,conn)
for i in range(1,len(results)+1):
......@@ -332,7 +332,7 @@ def main():
table_name="ck_fmha_bwd_tflops"
tflops_base = get_baseline(table_name,conn)
store_new_test_result(table_name, results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, conn)
store_new_test_result(table_name, results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, sqlEngine)
conn.close()
#compare the results to the baseline if baseline exists
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment