Commit 0ef27d53 authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Merge remote-tracking branch 'origin/gfx950' into andriy/lwpck-2682

parents 6778c318 74a743e2
add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
add_executable(tile_example_universal_gemm EXCLUDE_FROM_ALL universal_gemm.cpp) add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp)
...@@ -11,9 +11,9 @@ sh ../script/cmake-ck-dev.sh ../ <arch> ...@@ -11,9 +11,9 @@ sh ../script/cmake-ck-dev.sh ../ <arch>
# The basic pipeline method on the gemm calculation # The basic pipeline method on the gemm calculation
make tile_example_gemm_basic -j make tile_example_gemm_basic -j
# The memory bound pipeline on the gemm calculation # The memory bound pipeline on the gemm calculation
make tile_example_gemm_mem_pipeline -j make tile_example_gemm_universal -j
``` ```
This will result in an executable `build/bin/tile_example_gemm_basic` This will result in an executable `build/bin/tile_example_gemm_basic` & `build/bin/tile_example_gemm_universal`
## example ## example
``` ```
...@@ -22,6 +22,9 @@ args: ...@@ -22,6 +22,9 @@ args:
-m m dimension (default:1024) -m m dimension (default:1024)
-n n dimension (default:2048) -n n dimension (default:2048)
-k k dimension (default:64) -k k dimension (default:64)
-a_layout Tensor A data layout (default: R)
-b_layout Tensor B data layout (default: R)
-c_layout Tensor C data layout (default: R)
-stride_a Tensor A stride (default:0) -stride_a Tensor A stride (default:0)
-stride_b Tensor B stride (default:0) -stride_b Tensor B stride (default:0)
-stride_c Tensor C stride (default:0) -stride_c Tensor C stride (default:0)
......
...@@ -9,13 +9,11 @@ ...@@ -9,13 +9,11 @@
#include <string> #include <string>
#include <tuple> #include <tuple>
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
#include "gemm_basic.hpp" #include "gemm_basic.hpp"
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{ {
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false; constexpr bool kPadM = false;
...@@ -79,17 +77,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -79,17 +77,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>; using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args.p_a, auto kargs = Kernel::MakeKernelArgs(args);
args.p_b,
args.p_c, const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
args.M,
args.N,
args.K,
args.stride_A,
args.stride_B,
args.stride_C);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize(); constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs)) if(!Kernel::IsSupportedArgument(kargs))
......
...@@ -8,6 +8,27 @@ ...@@ -8,6 +8,27 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#define CK_TILE_PIPELINE_COMPUTE 1
#define CK_TILE_PIPELINE_MEMORY 2
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#else
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif
template <typename DataType> template <typename DataType>
struct GemmBasicTypeConfig; struct GemmBasicTypeConfig;
...@@ -51,25 +72,10 @@ using BDataType = Types::BDataType; ...@@ -51,25 +72,10 @@ using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType; using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType; using CDataType = Types::CDataType;
struct gemm_basic_args
{
const void* p_a;
const void* p_b;
void* p_c;
ck_tile::index_t kbatch;
ck_tile::index_t M;
ck_tile::index_t N;
ck_tile::index_t K;
ck_tile::index_t stride_A;
ck_tile::index_t stride_B;
ck_tile::index_t stride_C;
};
auto create_args(int argc, char* argv[]) auto create_args(int argc, char* argv[])
{ {
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
arg_parser.insert("b", "1", "batch size") arg_parser.insert("m", "3840", "m dimension")
.insert("m", "3840", "m dimension")
.insert("n", "4096", "n dimension") .insert("n", "4096", "n dimension")
.insert("k", "2048", "k dimension") .insert("k", "2048", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default") .insert("a_layout", "R", "A tensor data layout - Row by default")
...@@ -82,11 +88,12 @@ auto create_args(int argc, char* argv[]) ...@@ -82,11 +88,12 @@ auto create_args(int argc, char* argv[])
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer"); .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value");
bool result = arg_parser.parse(argc, argv); bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser); return std::make_tuple(result, arg_parser);
} }
// host API // host API
float gemm_calc(gemm_basic_args args, const ck_tile::stream_config& s); float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
...@@ -16,11 +16,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ...@@ -16,11 +16,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_warmup, int n_warmup,
int n_repeat) int n_repeat)
{ {
gemm_basic_args args; ck_tile::GemmHostArgs args;
args.p_a = a_m_k_dev_buf.GetDeviceBuffer(); args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.p_b = b_k_n_dev_buf.GetDeviceBuffer(); args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.p_c = c_m_n_dev_buf.GetDeviceBuffer(); args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.kbatch = kbatch; args.k_batch = kbatch;
args.M = M; args.M = M;
args.N = N; args.N = N;
args.K = K; args.K = K;
...@@ -64,9 +64,9 @@ int run_gemm_example_with_layouts(int argc, ...@@ -64,9 +64,9 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
ck_tile::index_t batch_size = arg_parser.get_int("b"); ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup"); int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat"); int n_repeat = arg_parser.get_int("repeat");
using namespace ck_tile::literals; using namespace ck_tile::literals;
...@@ -133,7 +133,7 @@ int run_gemm_example_with_layouts(int argc, ...@@ -133,7 +133,7 @@ int run_gemm_example_with_layouts(int argc,
stride_A, stride_A,
stride_B, stride_B,
stride_C, stride_C,
batch_size, kbatch,
n_warmup, n_warmup,
n_repeat); n_repeat);
...@@ -161,14 +161,39 @@ int run_gemm_example_with_layouts(int argc, ...@@ -161,14 +161,39 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_gpu_ref.SetZero(); c_m_n_gpu_ref.SetZero();
c_m_n_gpu_buf_ref.SetZero(); c_m_n_gpu_buf_ref.SetZero();
ADataType* d_A;
BDataType* d_B;
CDataType* d_C;
ck_tile::hip_check_error(hipMalloc(&d_A, M * K * sizeof(ADataType)));
ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType)));
ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType)));
ck_tile::hip_check_error(hipMemcpy(d_A,
a_m_k_dev_buf.GetDeviceBuffer(),
M * K * sizeof(ADataType),
hipMemcpyHostToDevice));
ck_tile::hip_check_error(hipMemcpy(d_B,
b_k_n_dev_buf.GetDeviceBuffer(),
N * K * sizeof(BDataType),
hipMemcpyHostToDevice));
ck_tile::reference_gemm_gpu<ADataType, ck_tile::reference_gemm_gpu<ADataType,
BDataType, BDataType,
AccDataType, AccDataType,
CDataType, CDataType,
ALayout, ALayout,
BLayout, BLayout,
CLayout>( CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_gpu_buf_ref, M, N, K, stride_A, stride_B, stride_C);
ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(),
d_C,
M * N * sizeof(CDataType),
hipMemcpyDeviceToHost));
ck_tile::hip_check_error(hipFree(d_A));
ck_tile::hip_check_error(hipFree(d_B));
ck_tile::hip_check_error(hipFree(d_C));
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref); pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref);
......
#!/bin/sh
EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)"
VALID=0
for b_matrix_layout in "R" "C"; do
for m in "64" "512" "1024" "2048"; do
for n in "512" "1024" "2048"; do
for k in "64" "512" "1024" "2048"; do
$EXE -prec=fp16 -b=1 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
done
done
done
done
#!/bin/sh
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
VALID=0
for b_matrix_layout in "R" "C"; do
for m in "64" "512" "1024" "2048"; do
for n in "512" "1024" "2048"; do
for k in "64" "512" "1024" "2048"; do
$EXE -prec=fp16 -b=1 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
done
done
done
done
...@@ -19,7 +19,27 @@ echo 'Host name: ' $host_name ...@@ -19,7 +19,27 @@ echo 'Host name: ' $host_name
export GPU_arch=$4 export GPU_arch=$4
echo 'GPU_arch: ' $GPU_arch echo 'GPU_arch: ' $GPU_arch
function print_log_header(){
rm -f $1;
echo 'On branch ' $3 &> $1;
echo 'Node name: ' $4 >> $1;
# get GPU architecture and compute units from rocminfo
echo -n "GPU_arch: " >> $1; rocminfo | grep "Name:" | grep "gfx" >> $1;
rocminfo | grep "Compute Unit:" >> $1;
hipcc --version | grep -e 'HIP version' >> $1;
echo 'Environment type: ' $2 >> $1;
/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> $1;
}
# run verification tests # run verification tests
example/ck_tile/03_gemm/script/smoke_test.sh example/ck_tile/03_gemm/script/smoke_test_basic.sh
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
# run performance benchmarks
export gemm_basic_log="perf_tile_gemm_basic_fp16_$GPU_arch.log"
print_log_header $gemm_basic_log $env_type $branch $host_name
example/ck_tile/03_gemm/script/benchmark_basic.sh 2>&1 | tee -a $gemm_basic_log
# We do not have a performance benchmark for gemm yet. Will add it in the future. export gemm_mem_pipeline_log="perf_tile_gemm_mem_pipeline_fp16_$GPU_arch.log"
\ No newline at end of file print_log_header $gemm_mem_pipeline_log $env_type $branch $host_name
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh 2>&1 | tee -a $gemm_mem_pipeline_log
...@@ -32,4 +32,4 @@ set -x ...@@ -32,4 +32,4 @@ set -x
run_fp16_tests run_fp16_tests
set +x set +x
\ No newline at end of file
#!/bin/bash
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
KNAME=1
export CK_WARMUP=0
export CK_REPEAT=1
COMMON_ARGS='-v=2 -warmup=0 -repeat=1'
run_fp16_tests() {
for batch in 1 2; do
for m in 128 1024; do
for n in 128 2048; do
for k in 32 64; do
$EXE -b=$batch -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -e=1e-5 -prec=fp16 $COMMON_ARGS
if [ $? -eq 0 ]; then
echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully."
else
echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly."
# Optionally, exit or break if you need to halt further execution
# exit 1
fi
done
done
done
done
}
set -x
run_fp16_tests
set +x
...@@ -9,20 +9,11 @@ ...@@ -9,20 +9,11 @@
#include <string> #include <string>
#include <tuple> #include <tuple>
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
#include "gemm_basic.hpp" #include "gemm_basic.hpp"
#define CK_TILE_PIPELINE_COMPUTE 1
#define CK_TILE_PIPELINE_MEMORY 2
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#endif
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{ {
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Memory friendly for Interwave scheduler // Memory friendly for Interwave scheduler
...@@ -71,14 +62,15 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -71,14 +62,15 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>; ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<
#endif
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K); using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE<GemmPipelineProblem>;
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
...@@ -87,36 +79,22 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -87,36 +79,22 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value; constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem< using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) BDataType,
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3< AccDataType,
#endif GemmShape,
ck_tile::UniversalGemmPipelineProblem<ADataType, Traits,
BDataType, scheduler,
AccDataType, has_hot_loop_v,
GemmShape, tail_number_v>;
Traits,
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
ck_tile::GemmPipelineScheduler::Interwave, using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) auto kargs = Kernel::MakeKernelArgs(args);
ck_tile::GemmPipelineScheduler::Intrawave,
#endif const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
has_hot_loop_v,
tail_number_v>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args.p_a,
args.p_b,
args.p_c,
args.M,
args.N,
args.K,
args.stride_A,
args.stride_B,
args.stride_C);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize(); constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs)) if(!Kernel::IsSupportedArgument(kargs))
......
...@@ -52,7 +52,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -52,7 +52,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// using WarpTile = ck_tile::sequence<1, 512>; // using WarpTile = ck_tile::sequence<1, 512>;
// using Vector = ck_tile::sequence<1, 8>; // using Vector = ck_tile::sequence<1, 8>;
constexpr ck_tile::index_t kBlockSize = 512; constexpr ck_tile::index_t kBlockSize = 256;
constexpr ck_tile::index_t kBlockPerCu = 1; constexpr ck_tile::index_t kBlockPerCu = 1;
ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{})); ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{}));
std::cout << "grid size " << kGridSize << std::endl; std::cout << "grid size " << kGridSize << std::endl;
......
set(RMSNORM2D_FWD_KNOWN_APIS "fwd;bwd")
set(RMSNORM2D_FWD_ENABLE_APIS "fwd" CACHE STRING
"semicolon-separated list of APIs to generate (${RMSNORM2D_FWD_KNOWN_APIS}) & link, or \"all\".")
if(RMSNORM2D_FWD_ENABLE_APIS STREQUAL "all")
set(RMSNORM2D_FWD_ENABLE_APIS ${RMSNORM2D_FWD_KNOWN_APIS})
endif()
# generate a list of kernels, but not actually emit files at config sta
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${RMSNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --list_blobs
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "Fail to generate kernels via Python. ${ret}")
endif()
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/rmsnorm2d_fwd_blobs.txt RMSNORM2D_FWD_GEN_BLOBS)
add_custom_command(
OUTPUT ${RMSNORM2D_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${RMSNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --gen_blobs
)
set(TILE_RMSNORM2D_FWD "tile_rmsnorm2d_fwd") set(TILE_RMSNORM2D_FWD "tile_rmsnorm2d_fwd")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message("adding ${TILE_RMSNORM2D_FWD}") message("adding ${TILE_RMSNORM2D_FWD}")
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_executable(${TILE_RMSNORM2D_FWD} EXCLUDE_FROM_ALL rmsnorm2d_fwd.cpp) add_executable(${TILE_RMSNORM2D_FWD} EXCLUDE_FROM_ALL rmsnorm2d_fwd.cpp)
target_include_directories(${TILE_RMSNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_include_directories(${TILE_RMSNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${INSTANCE_SRCS}) target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${RMSNORM2D_FWD_GEN_BLOBS})
set(TILE_RMSNORM2D_FWD_COMPILE_OPTIONS) set(TILE_RMSNORM2D_FWD_COMPILE_OPTIONS)
......
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/rmsnorm2d.hpp" #include "ck_tile/ops/rmsnorm2d.hpp"
#include <cstring> #include <cstring>
...@@ -36,10 +37,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -36,10 +37,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
assert(stride >= n); assert(stride >= n);
using XDataType = DataType; using XDataType = DataType;
using YDataType = DataType; using YDataType = DataType;
using GammaDataType = DataType; using GammaDataType = DataType;
using InvRmsDataType = ck_tile::null_type; using InvRmsDataType = ck_tile::null_type;
using SmoothScaleDataType = ck_tile::null_type;
using YScaleDataType = ck_tile::null_type;
using ComputeDataType = float; using ComputeDataType = float;
...@@ -68,30 +71,49 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -68,30 +71,49 @@ bool run(const ck_tile::ArgParser& arg_parser)
using BlockTile = ck_tile::sequence<2, 128>; using BlockTile = ck_tile::sequence<2, 128>;
using WarpTile = ck_tile::sequence<1, 64>; using WarpTile = ck_tile::sequence<1, 64>;
using Vector = ck_tile::sequence<1, 1>; using Vector = ck_tile::sequence<1, 1>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
using PipelineTraits =
ck_tile::Rmsnorm2dFwdTraits<true, // kPadN
false, // kSaveInvRms
kTwoPass,
ck_tile::Rmsnorm2dFusedAddEnum::NO_ADD, // fuse add
ck_tile::Rmsnorm2dFusedQuantEnum::NO_SWEEP>; // fuse quant
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Problem = ck_tile::Rmsnorm2dFwdPipelineProblem<XDataType, using Problem = ck_tile::Rmsnorm2dFwdPipelineProblem<XDataType,
GammaDataType, GammaDataType,
ComputeDataType, ComputeDataType,
YDataType, YDataType,
InvRmsDataType, InvRmsDataType,
SmoothScaleDataType,
YScaleDataType,
Shape, Shape,
true, // kPadN PipelineTraits>;
false, // kSaveInvRms
kTwoPass>;
using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass<Problem>; using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass<Problem>;
using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<Problem>; using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<Problem>;
using Pipeline = std::conditional_t<kTwoPass, TwoPassPipeline, OnePassPipeline>; using Pipeline = std::conditional_t<kTwoPass, TwoPassPipeline, OnePassPipeline>;
using Kernel = ck_tile::Rmsnorm2dFwd<Pipeline>;
using Default2DEpilogueProblem = ck_tile::
Default2DEpilogueProblem<ComputeDataType, YDataType, false, PipelineTraits::kPadN, false>;
using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
using Kernel = ck_tile::Rmsnorm2dFwd<Pipeline, Default2DEpilogue>;
ck_tile::Rmsnorm2dFwdHostArgs args{x_buf.GetDeviceBuffer(), ck_tile::Rmsnorm2dFwdHostArgs args{x_buf.GetDeviceBuffer(),
nullptr,
nullptr,
gamma_buf.GetDeviceBuffer(), gamma_buf.GetDeviceBuffer(),
y_buf.GetDeviceBuffer(), y_buf.GetDeviceBuffer(),
nullptr, nullptr,
nullptr,
nullptr,
epsilon, epsilon,
m, m,
n, n,
stride,
stride,
stride,
stride}; stride};
auto kargs = Kernel::MakeKargs(args); auto kargs = Kernel::MakeKargs(args);
......
This diff is collapsed.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "rmsnorm2d_fwd.hpp"
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveInvRms_,
bool kTwoPass_>
using trait_ = rmsnorm2d_fwd_traits_<DataType_,
Repeat_M_,
Repeat_N_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_,
kSaveInvRms_,
kTwoPass_>;
template <typename data_type>
float rmsnorm2d_fwd_b16_(rmsnorm2d_fwd_traits /*t*/,
rmsnorm2d_fwd_args a,
const ck_tile::stream_config& s)
{
float r = -1;
// clang-format off
// rm rn tm tn vn pd rms 2p
if(a.n <= 64) {
r = rmsnorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 1, true, false, false>>(s, a);
}
else if(a.n <= 128) {
if (a.n % 2 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 2, true, false, false>>(s, a);
else
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 4, 64, 1, true, false, false>>(s, a);
}
else if(a.n <= 256) {
if (a.n % 4 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 4, 64, 2, true, false, false>>(s, a);
else
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 4, 64, 1, true, false, false>>(s, a);
}
else if(a.n <= 512) {
if (a.n % 8 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 4, 64, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 4, 64, 2, true, false, false>>(s, a);
else
r = rmsnorm2d_fwd_<trait_<data_type, 1, 8, 4, 64, 1, true, false, false>>(s, a);
}
else if(a.n <= 768) {
if (a.n % 4 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 4, 64, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 6, 4, 64, 2, true, false, false>>(s, a);
else
r = rmsnorm2d_fwd_<trait_<data_type, 1,12, 4, 64, 1, true, false, false>>(s, a);
}
else if(a.n <= 1024) {
if (a.n % 8 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 1, 2, 128, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 2, 128, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 2, 128, 2, true, false, false>>(s, a);
else
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 1, true, false, false>>(s, a);
}
else if(a.n <= 1536) {
if (a.n % 8 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 4, 64, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 2, 128, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 1, 256, 2, true, false, false>>(s, a);
else
r = rmsnorm2d_fwd_<trait_<data_type, 1, 6, 1, 256, 1, true, false, false>>(s, a);
}
else if(a.n <= 2048) {
if (a.n % 8 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 1, 1, 256, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 1, 256, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 2, true, false, false>>(s, a);
else
r = rmsnorm2d_fwd_<trait_<data_type, 1, 8, 1, 256, 1, true, false, false>>(s, a);
}
else if(a.n <= 3072) {
if (a.n % 8 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 1, 128, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 1, 256, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 6, 1, 256, 2, true, false, false>>(s, a);
else
r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 1, 1024, 1, true, false, false>>(s, a);
}
else if(a.n <= 4096) {
if (a.n % 8 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 1, 256, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 1, 1024, 2, true, false, false>>(s, a);
else
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 1, 1024, 1, true, false, false>>(s, a);
}
else if(a.n > 4096) {
if (a.n % 8 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 1, 256, 8, true, false, true>>(s, a);
else if (a.n % 4 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 4, true, false, true>>(s, a);
else if (a.n % 2 == 0)
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 1, 1024, 2, true, false, true>>(s, a);
else
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 1, 1024, 1, true, false, true>>(s, a);
}
return r;
// clang-format on
}
float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, rmsnorm2d_fwd_args a, const ck_tile::stream_config& s)
{
if(t.data_type.compare("fp16") == 0)
{
return rmsnorm2d_fwd_b16_<ck_tile::fp16_t>(t, a, s);
}
else if(t.data_type.compare("bf16") == 0)
{
return rmsnorm2d_fwd_b16_<ck_tile::bf16_t>(t, a, s);
}
else
throw std::runtime_error("Without supported instances!");
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
#if 0
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
#endif
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 2, 128, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 2, 128, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 2, 128, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 2, 128, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment