Unverified Commit 3d609534 authored by rocking's avatar rocking Committed by GitHub
Browse files

[Ck tile] support rmsnorm and related fusion (#1605)

* Add reduce2d new api

* Prevent user use cross warp reduction

* Fix bug of std caculation

* Add rmsnorm2d

* Add rmsnorm small example

* Remove static assert to prevent compile fail

* Add script to test performance and correctness

* Add missing cmake change

* refine naming

* refine example of rmsnorm

* Fix bug of rmsnorm

* Refine naming

* Fix cmake

* clang format

* Refine pipeline name

* Add add_rmsnorm2d_rdquant kernel

* Add reduce op

* host verification

* Fix bug of one pass pipeline

* Refine tile size

* Add two pass pipeline

* Rename two pass to three pass

* Fix bug of kSaveX == false

* Add instance library

* Add test script

* Fix bug of x verification

* Add save_x to trait

* Add README

* Move reduce2d into reduce folder

* Fix bug of welford when number of m warp > 1

* remove reduncant comment

* 1. move 06_rmsnorm2d to 10_rmsnorm2d
2. move 07_add_rmsnorm2d_rdquant to 11_add_rmsnorm2d_rdquant

* clang format and add missing header

* Add host validation of add + layernorm2d + rsquant

* Revert "Add host validation of add + layernorm2d + rsquant"

This reverts commit 936cb457978b928b90eff89a08fcdb7dc8bbed67.

* Remove deprecated flag
parent 86322218
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 128, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 256, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 1, 256, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 1024, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 8, true, false, true>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 4, true, false, true>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, false, true>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, false, true>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 8, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 4, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 1, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 4, 64, 4, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 12, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "rmsnorm2d_fwd.hpp"
#include <iostream>
#pragma once
using S = ck_tile::stream_config;
using A = rmsnorm2d_fwd_args;
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveInvRms_,
bool kTwoPass_>
using trait_ = rmsnorm2d_fwd_traits_<DataType_,
Repeat_M_,
Repeat_N_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_,
kSaveInvRms_,
kTwoPass_>;
template <typename Traits_>
float rmsnorm2d_fwd_(const S& s, A a)
{
using DataType = typename Traits_::DataType;
using PipelineProblem =
ck_tile::Rmsnorm2dFwdPipelineProblem<typename RmsnormTypeConfig<DataType>::XDataType,
typename RmsnormTypeConfig<DataType>::GammaDataType,
typename RmsnormTypeConfig<DataType>::ComputeDataType,
typename RmsnormTypeConfig<DataType>::YDataType,
typename RmsnormTypeConfig<DataType>::InvRmsDataType,
typename Traits_::Shape,
Traits_::kPadN,
Traits_::kSaveInvRms,
Traits_::kTwoPass>;
using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass<PipelineProblem>;
using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<PipelineProblem>;
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
using Kernel = ck_tile::Rmsnorm2dFwd<Pipeline>;
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto kargs = Kernel::MakeKargs(a);
if(s.log_level_ > 0)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
#include "ck_tile/host.hpp"
#include "rmsnorm2d_fwd.hpp"
#include <cstring>
// different threshold for different dtype
template <typename DataType>
auto get_elimit()
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>()
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3328", "m dimension")
.insert("n", "4096", "n dimension")
.insert("stride", "-1", "stride per row, if -1 then equal to n")
.insert("e", "1e-5", "epsilon")
.insert("save_rms", "0", "save rms(invrms) or not. set to 1 in training case")
.insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not")
.insert("prec", "fp16", "precision")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename DataType, bool SaveRms>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
ck_tile::index_t stride = arg_parser.get_int("stride");
if(stride < 0)
stride = n;
float epsilon = arg_parser.get_float("e");
std::string data_type = arg_parser.get_str("prec");
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
assert(stride >= n);
using TypeConfig = RmsnormTypeConfig<DataType>;
using XDataType = typename TypeConfig::XDataType;
using YDataType = typename TypeConfig::YDataType;
using GammaDataType = typename TypeConfig::GammaDataType;
using InvRmsDataType =
std::conditional_t<SaveRms, typename TypeConfig::InvRmsDataType, ck_tile::null_type>;
using ComputeDataType = typename TypeConfig::ComputeDataType;
// host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1});
ck_tile::HostTensor<GammaDataType> gamma_host({n});
ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {stride, 1});
ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {stride, 1});
ck_tile::HostTensor<InvRmsDataType> invRms_host_ref({m});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data());
gamma_buf.ToDevice(gamma_host.data());
std::cout << "[" << data_type << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
rmsnorm2d_fwd_traits traits{data_type, SaveRms};
rmsnorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
gamma_buf.GetDeviceBuffer(),
y_buf.GetDeviceBuffer(),
nullptr,
epsilon,
m,
n,
stride};
float ave_time = rmsnorm2d_fwd(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
std::size_t num_byte =
sizeof(XDataType) * m * n + sizeof(GammaDataType) * n + sizeof(YDataType) * m * n;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
bool pass = true;
if(do_validation)
{
// reference
ck_tile::reference_rmsnorm2d_fwd<XDataType,
GammaDataType,
ComputeDataType,
YDataType,
InvRmsDataType>(
x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon);
y_buf.FromDevice(y_host_dev.data());
auto [rtol, atol] = get_elimit<DataType>();
if(stride == n)
{
pass = ck_tile::check_err(
y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
}
else
{
for(int i_r = 0; i_r < m; i_r++)
{
std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * stride,
y_host_dev.begin() + i_r * stride + n);
std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * stride,
y_host_ref.begin() + i_r * stride + n);
pass &= ck_tile::check_err(y_host_dev_row,
y_host_ref_row,
std::string("OUT[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
atol);
}
}
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
const std::string data_type = arg_parser.get_str("prec");
int save_rms = arg_parser.get_int("save_rms");
if(data_type == "fp16" && save_rms)
{
return run<ck_tile::half_t, true>(arg_parser) ? 0 : -2;
}
else if(data_type == "fp16" && !save_rms)
{
return run<ck_tile::half_t, false>(arg_parser) ? 0 : -2;
}
else if(data_type == "bf16" && save_rms)
{
return run<ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
}
else if(data_type == "bf16" && !save_rms)
{
return run<ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
}
return -3;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/rmsnorm2d.hpp"
#include <string>
template <typename DataType>
struct RmsnormTypeConfig;
template <>
struct RmsnormTypeConfig<ck_tile::half_t>
{
using XDataType = ck_tile::half_t;
using YDataType = ck_tile::half_t;
using GammaDataType = ck_tile::half_t;
using InvRmsDataType = ck_tile::half_t;
using ComputeDataType = float;
};
template <>
struct RmsnormTypeConfig<ck_tile::bf16_t>
{
using XDataType = ck_tile::bf16_t;
using YDataType = ck_tile::bf16_t;
using GammaDataType = ck_tile::bf16_t;
using InvRmsDataType = ck_tile::bf16_t;
using ComputeDataType = float;
};
// runtime args
struct rmsnorm2d_fwd_args : public ck_tile::Rmsnorm2dFwdHostArgs
{
};
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveInvRms_,
bool kTwoPass_>
struct rmsnorm2d_fwd_traits_
{
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return total_warps * (warpSize / ThreadPerBlock_N_);
}
else
{
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / warpSize);
}
}();
// num of warps along n
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % warpSize == 0);
return ThreadPerBlock_N_ / warpSize;
}
}();
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using Shape = ck_tile::Rmsnorm2dShape<BlockTile, BlockWarps, WarpTile, Vector>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveInvRms = kSaveInvRms_;
static constexpr bool kTwoPass = kTwoPass_;
};
template <typename Traits_>
float rmsnorm2d_fwd_(const ck_tile::stream_config& s, rmsnorm2d_fwd_args a);
// This is the public API, will be generated by script
struct rmsnorm2d_fwd_traits
{
std::string data_type;
bool save_rms;
};
float rmsnorm2d_fwd(rmsnorm2d_fwd_traits, rmsnorm2d_fwd_args, const ck_tile::stream_config&);
# run from top of ck folder
EXE=build/bin/tile_rmsnorm2d_fwd
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
\ No newline at end of file
#!/bin/sh
# call from top of CK folder
EXE=./build/bin/tile_rmsnorm2d_fwd
for pr_i in "fp16" "bf16" ; do
$EXE -prec=$pr_i -m=99 -n=13
$EXE -prec=$pr_i -m=17 -n=16
$EXE -prec=$pr_i -m=1 -n=100
$EXE -prec=$pr_i -m=4 -n=128
$EXE -prec=$pr_i -m=80 -n=127
$EXE -prec=$pr_i -m=22 -n=255 -stride=256
$EXE -prec=$pr_i -m=7 -n=599
$EXE -prec=$pr_i -m=19 -n=512
$EXE -prec=$pr_i -m=33 -n=313 -stride=1000
$EXE -prec=$pr_i -m=11 -n=510
$EXE -prec=$pr_i -m=171 -n=676 -stride=818
$EXE -prec=$pr_i -m=91 -n=636
$EXE -prec=$pr_i -m=12 -n=768 -stride=800
$EXE -prec=$pr_i -m=100 -n=766 -stride=812
$EXE -prec=$pr_i -m=31 -n=1024
$EXE -prec=$pr_i -m=64 -n=1000 -stride=1004
$EXE -prec=$pr_i -m=8 -n=1501
$EXE -prec=$pr_i -m=3 -n=1826
$EXE -prec=$pr_i -m=5 -n=2040
$EXE -prec=$pr_i -m=7 -n=2734
$EXE -prec=$pr_i -m=1 -n=3182
$EXE -prec=$pr_i -m=9 -n=4096
$EXE -prec=$pr_i -m=3 -n=8192
$EXE -prec=$pr_i -m=1 -n=10547
$EXE -prec=$pr_i -m=3 -n=17134
done
set(TILE_ADD_RMSNORM2D_RDQUANT_FWD "tile_add_rmsnorm2d_rdquant_fwd")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message("adding ${TILE_ADD_RMSNORM2D_RDQUANT_FWD}")
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_executable(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} EXCLUDE_FROM_ALL add_rmsnorm2d_rdquant_fwd.cpp)
target_include_directories(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${INSTANCE_SRCS})
set(TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
target_compile_options(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS})
set(EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD "tile_example_add_rmsnorm2d_rdquant_fwd")
add_executable(${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} EXCLUDE_FROM_ALL example_add_rmsnorm2d_rdquant_fwd.cpp)
target_compile_options(${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global
# TODO: consider codegen a makefile by us
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
# Add + Rmsnorm2D + rowwise dynamic quantization forward
This folder contains example for add + Rmsnorm2D + rowwise dynamic quantization forward using ck_tile tile-programming implementation. Rdquant is short for rowwise dynamic quantization here.
## build
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_add_rmsnorm2d_rdquant_fwd -j
```
This will result in an executable `build/bin/tile_add_rmsnorm2d_rdquant_fwd`
## cmdline
```
args:
-m m dimension (default:3328)
-n m dimension (default:4096)
-e epsilon (default:1e-5)
-v cpu validation or not (default:1)
-prec precision (default:fp16)
```
#include "ck_tile/host.hpp"
#include "add_rmsnorm2d_rdquant_fwd.hpp"
#include <cstring>
// different threshold for different dtype
template <typename DataType>
auto get_elimit()
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>()
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::int8_t>()
{
// due to rounding, int8 quantization might have 1 abs error
double rtol = 1;
double atol = 1;
return ck_tile::make_tuple(rtol, atol);
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3328", "m dimension")
.insert("n", "4096", "n dimension")
.insert("stride", "-1", "stride per row, if -1 then equal to n")
.insert("e", "1e-5", "epsilon")
.insert("save_x", "1", "save rms(invrms) or not. set to 1 in training case")
.insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not")
.insert("prec", "fp16", "precision")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename DataType, bool SaveX>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
ck_tile::index_t stride = arg_parser.get_int("stride");
if(stride < 0)
stride = n;
float epsilon = arg_parser.get_float("e");
std::string data_type = arg_parser.get_str("prec");
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
assert(stride >= n);
using TypeConfig = AddRmsnormRdquantTypeConfig<DataType>;
using ADataType = typename TypeConfig::ADataType;
using BDataType = typename TypeConfig::BDataType;
using GammaDataType = typename TypeConfig::GammaDataType;
using XDataType = typename TypeConfig::XDataType;
using YScaleDataType = typename TypeConfig::YScaleDataType;
using QYDataType = typename TypeConfig::QYDataType;
using ComputeDataType = float;
// host verify
ck_tile::HostTensor<ADataType> a_host({m, n}, {stride, 1});
ck_tile::HostTensor<BDataType> b_host({m, n}, {stride, 1});
ck_tile::HostTensor<GammaDataType> gamma_host({n});
ck_tile::HostTensor<XDataType> x_host_ref({m, n}, {stride, 1});
ck_tile::HostTensor<XDataType> x_host_dev({m, n}, {stride, 1});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {stride, 1});
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {stride, 1});
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_host);
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem x_buf(x_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
a_buf.ToDevice(a_host.data());
b_buf.ToDevice(b_host.data());
gamma_buf.ToDevice(gamma_host.data());
std::cout << "[" << data_type << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
add_rmsnorm2d_rdquant_fwd_traits traits{data_type, SaveX};
add_rmsnorm2d_rdquant_fwd_args args{a_buf.GetDeviceBuffer(),
b_buf.GetDeviceBuffer(),
gamma_buf.GetDeviceBuffer(),
x_buf.GetDeviceBuffer(),
yscale_buf.GetDeviceBuffer(),
qy_buf.GetDeviceBuffer(),
epsilon,
m,
n,
stride};
float ave_time = add_rmsnorm2d_rdquant_fwd(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
std::size_t num_byte = sizeof(ADataType) * m * n + sizeof(BDataType) * m * n +
sizeof(GammaDataType) * n + sizeof(YScaleDataType) * m +
sizeof(QYDataType) * m * n;
if constexpr(SaveX)
num_byte += sizeof(XDataType) * m * n;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
bool pass = true;
if(do_validation)
{
using YDataType = ComputeDataType;
using InvRmsDataType = DataType;
// Add
{
auto op = [](const auto& v0, const auto& v1) { return v0 + v1; };
ck_tile::reference_binary_elementwise<ADataType, BDataType, XDataType, ComputeDataType>(
a_host, b_host, x_host_ref, op);
x_buf.FromDevice(x_host_dev.data());
auto [rtol, atol] = get_elimit<XDataType>();
if(stride == n)
{
pass = ck_tile::check_err(
x_host_dev, x_host_ref, std::string("x Error: Incorrect results!"), rtol, atol);
}
else
{
for(int i_r = 0; i_r < m; i_r++)
{
std::vector<QYDataType> x_host_dev_row(x_host_dev.begin() + i_r * stride,
x_host_dev.begin() + i_r * stride + n);
std::vector<QYDataType> x_host_ref_row(x_host_ref.begin() + i_r * stride,
x_host_ref.begin() + i_r * stride + n);
pass &= ck_tile::check_err(x_host_dev_row,
x_host_ref_row,
std::string("x[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
atol);
}
}
}
ck_tile::HostTensor<YDataType> y_host({m, n});
// Rmsnorm2d
{
ck_tile::HostTensor<InvRmsDataType> invRms_host_ref({m});
// CAUSION: kernel use ComputeDataType version of x, but we use XDataType here for
// simplicity
ck_tile::reference_rmsnorm2d_fwd<XDataType,
GammaDataType,
ComputeDataType,
YDataType,
InvRmsDataType>(
x_host_ref, gamma_host, y_host, invRms_host_ref, epsilon);
}
// yscale
{
ck_tile::HostTensor<YDataType> y_rowwise_amax_host({m});
using ReduceAmax = ck_tile::ReduceOp::AbsMax;
ck_tile::reference_reduce<YDataType, ComputeDataType, YDataType>(
y_host, y_rowwise_amax_host, ReduceAmax{});
auto op = [](const auto& v0) {
return v0 /
ck_tile::type_convert<ComputeDataType>(ck_tile::numeric<QYDataType>::max());
};
ck_tile::reference_unary_elementwise<YDataType, YScaleDataType, ComputeDataType>(
y_rowwise_amax_host, yscale_host_ref, op);
yscale_buf.FromDevice(yscale_host_dev.mData.data());
auto [rtol, atol] = get_elimit<YScaleDataType>();
pass &= ck_tile::check_err(yscale_host_dev,
yscale_host_ref,
std::string("yscale Error: Incorrect results!"),
rtol,
atol);
}
// rowwise quantization
{
ck_tile::reference_rowwise_quantization2d<YDataType, YScaleDataType, QYDataType>(
y_host, yscale_host_ref, qy_host_ref);
qy_buf.FromDevice(qy_host_dev.data());
auto [rtol, atol] = get_elimit<QYDataType>();
if(stride == n)
{
pass = ck_tile::check_err(qy_host_dev,
qy_host_ref,
std::string("qy Error: Incorrect results!"),
rtol,
atol);
}
else
{
for(int i_r = 0; i_r < m; i_r++)
{
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * stride,
qy_host_dev.begin() + i_r * stride + n);
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * stride,
qy_host_ref.begin() + i_r * stride + n);
pass &= ck_tile::check_err(qy_host_dev_row,
qy_host_ref_row,
std::string("qy[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
atol);
}
}
}
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
const std::string data_type = arg_parser.get_str("prec");
int save_x = arg_parser.get_int("save_x");
if(data_type == "fp16" && save_x)
{
return run<ck_tile::half_t, true>(arg_parser) ? 0 : -2;
}
else if(data_type == "fp16" && !save_x)
{
return run<ck_tile::half_t, false>(arg_parser) ? 0 : -2;
}
else if(data_type == "bf16" && save_x)
{
return run<ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
}
else if(data_type == "bf16" && !save_x)
{
return run<ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
}
return -3;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant.hpp"
#include <string>
template <typename DataType>
struct AddRmsnormRdquantTypeConfig;
template <>
struct AddRmsnormRdquantTypeConfig<ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using GammaDataType = ck_tile::half_t;
using XDataType = ck_tile::half_t;
using YScaleDataType = ck_tile::half_t;
using QYDataType = ck_tile::int8_t;
using ComputeDataType = float;
};
template <>
struct AddRmsnormRdquantTypeConfig<ck_tile::bf16_t>
{
using ADataType = ck_tile::bf16_t;
using BDataType = ck_tile::bf16_t;
using GammaDataType = ck_tile::bf16_t;
using XDataType = ck_tile::bf16_t;
using YScaleDataType = ck_tile::bf16_t;
using QYDataType = ck_tile::int8_t;
using ComputeDataType = float;
};
// runtime args
struct add_rmsnorm2d_rdquant_fwd_args : public ck_tile::AddRmsnorm2dRdquantFwdHostArgs
{
};
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveX_,
bool kThreePass_>
struct add_rmsnorm2d_rdquant_fwd_traits_
{
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return total_warps * (warpSize / ThreadPerBlock_N_);
}
else
{
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / warpSize);
}
}();
// num of warps along n
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % warpSize == 0);
return ThreadPerBlock_N_ / warpSize;
}
}();
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using Shape = ck_tile::AddRmsnorm2dRdquantShape<BlockTile, BlockWarps, WarpTile, Vector>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveX = kSaveX_;
static constexpr bool kThreePass = kThreePass_;
};
template <typename Traits_>
float add_rmsnorm2d_rdquant_fwd_(const ck_tile::stream_config& s, add_rmsnorm2d_rdquant_fwd_args a);
// This is the public API, will be generated by script
struct add_rmsnorm2d_rdquant_fwd_traits
{
std::string data_type;
bool save_x;
};
float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits,
add_rmsnorm2d_rdquant_fwd_args,
const ck_tile::stream_config&);
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant.hpp"
#include <cstring>
// different threshold for different dtype
template <typename DataType>
auto get_elimit()
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>()
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::int8_t>()
{
// due to rounding, int8 quantization might have 1 abs error
double rtol = 1;
double atol = 1;
return ck_tile::make_tuple(rtol, atol);
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3328", "m dimension")
.insert("n", "4096", "n dimension")
.insert("stride", "-1", "stride per row, if -1 then equal to n")
.insert("e", "1e-5", "epsilon")
.insert("v", "1", "cpu validation or not")
.insert("prec", "fp16", "precision")
.insert("warmup", "0", "cold iter")
.insert("repeat", "1", "hot iter");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
ck_tile::index_t stride = arg_parser.get_int("stride");
if(stride < 0)
stride = n;
float epsilon = arg_parser.get_float("e");
std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
assert(stride >= n);
using ADataType = DataType;
using BDataType = DataType;
using GammaDataType = DataType;
using XDataType = DataType;
using YScaleDataType = DataType;
using QYDataType = ck_tile::int8_t;
using ComputeDataType = float;
// host verify
ck_tile::HostTensor<ADataType> a_host({m, n}, {stride, 1});
ck_tile::HostTensor<BDataType> b_host({m, n}, {stride, 1});
ck_tile::HostTensor<GammaDataType> gamma_host({n});
ck_tile::HostTensor<XDataType> x_host_ref({m, n}, {stride, 1});
ck_tile::HostTensor<XDataType> x_host_dev({m, n}, {stride, 1});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {stride, 1});
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {stride, 1});
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_host);
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem x_buf(x_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
a_buf.ToDevice(a_host.data());
b_buf.ToDevice(b_host.data());
gamma_buf.ToDevice(gamma_host.data());
constexpr bool kThreePass = true;
using BlockWarps = ck_tile::sequence<2, 2>;
using BlockTile = ck_tile::sequence<2, 128>;
using WarpTile = ck_tile::sequence<1, 64>;
using Vector = ck_tile::sequence<1, 1>;
using Shape = ck_tile::AddRmsnorm2dRdquantShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Problem = ck_tile::AddRmsnorm2dRdquantFwdPipelineProblem<ADataType,
BDataType,
GammaDataType,
ComputeDataType,
XDataType,
YScaleDataType,
QYDataType,
Shape,
true, // kPadN
true, // kSaveX
kThreePass>;
using OnePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineOnePass<Problem>;
using ThreePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineThreePass<Problem>;
using Pipeline = std::conditional_t<kThreePass, ThreePassPipeline, OnePassPipeline>;
using Kernel = ck_tile::AddRmsnorm2dRdquantFwd<Pipeline>;
ck_tile::AddRmsnorm2dRdquantFwdHostArgs args{a_buf.GetDeviceBuffer(),
b_buf.GetDeviceBuffer(),
gamma_buf.GetDeviceBuffer(),
x_buf.GetDeviceBuffer(),
yscale_buf.GetDeviceBuffer(),
qy_buf.GetDeviceBuffer(),
epsilon,
m,
n,
stride};
auto kargs = Kernel::MakeKargs(args);
const dim3 grids = Kernel::GridSize(args);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto s = ck_tile::stream_config{nullptr, true, 0, warmup, repeat};
ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
bool pass = true;
if(do_validation)
{
using YDataType = ComputeDataType;
using InvRmsDataType = DataType;
// Add
{
auto op = [](const auto& v0, const auto& v1) { return v0 + v1; };
ck_tile::reference_binary_elementwise<ADataType, BDataType, XDataType, ComputeDataType>(
a_host, b_host, x_host_ref, op);
x_buf.FromDevice(x_host_dev.data());
auto [rtol, atol] = get_elimit<XDataType>();
if(stride == n)
{
pass = ck_tile::check_err(
x_host_dev, x_host_ref, std::string("x Error: Incorrect results!"), rtol, atol);
}
else
{
for(int i_r = 0; i_r < m; i_r++)
{
std::vector<QYDataType> x_host_dev_row(x_host_dev.begin() + i_r * stride,
x_host_dev.begin() + i_r * stride + n);
std::vector<QYDataType> x_host_ref_row(x_host_ref.begin() + i_r * stride,
x_host_ref.begin() + i_r * stride + n);
pass &= ck_tile::check_err(x_host_dev_row,
x_host_ref_row,
std::string("x[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
atol);
}
}
}
ck_tile::HostTensor<YDataType> y_host({m, n});
// Rmsnorm2d
{
ck_tile::HostTensor<InvRmsDataType> invRms_host_ref({m});
// CAUSION: kernel use ComputeDataType version of x, but we use XDataType here for
// simplicity
ck_tile::reference_rmsnorm2d_fwd<XDataType,
GammaDataType,
ComputeDataType,
YDataType,
InvRmsDataType>(
x_host_ref, gamma_host, y_host, invRms_host_ref, epsilon);
}
// yscale
{
ck_tile::HostTensor<YDataType> y_rowwise_amax_host({m});
using ReduceAmax = ck_tile::ReduceOp::AbsMax;
ck_tile::reference_reduce<YDataType, ComputeDataType, YDataType>(
y_host, y_rowwise_amax_host, ReduceAmax{});
auto op = [](const auto& v0) {
return v0 /
ck_tile::type_convert<ComputeDataType>(ck_tile::numeric<QYDataType>::max());
};
ck_tile::reference_unary_elementwise<YDataType, YScaleDataType, ComputeDataType>(
y_rowwise_amax_host, yscale_host_ref, op);
yscale_buf.FromDevice(yscale_host_dev.mData.data());
auto [rtol, atol] = get_elimit<YScaleDataType>();
pass &= ck_tile::check_err(yscale_host_dev,
yscale_host_ref,
std::string("yscale Error: Incorrect results!"),
rtol,
atol);
}
// rowwise quantization
{
ck_tile::reference_rowwise_quantization2d<YDataType, YScaleDataType, QYDataType>(
y_host, yscale_host_ref, qy_host_ref);
qy_buf.FromDevice(qy_host_dev.data());
auto [rtol, atol] = get_elimit<QYDataType>();
if(stride == n)
{
pass = ck_tile::check_err(qy_host_dev,
qy_host_ref,
std::string("qy Error: Incorrect results!"),
rtol,
atol);
}
else
{
for(int i_r = 0; i_r < m; i_r++)
{
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * stride,
qy_host_dev.begin() + i_r * stride + n);
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * stride,
qy_host_ref.begin() + i_r * stride + n);
pass &= ck_tile::check_err(qy_host_dev_row,
qy_host_ref_row,
std::string("qy[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
atol);
}
}
}
std::cout << "[" << data_type << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride
<< ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
}
return -3;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "add_rmsnorm2d_rdquant_fwd.hpp"
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveX_,
bool kThreePass_>
using trait_ = add_rmsnorm2d_rdquant_fwd_traits_<DataType_,
Repeat_M_,
Repeat_N_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_,
kSaveX_,
kThreePass_>;
template <typename data_type>
float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits /*t*/,
add_rmsnorm2d_rdquant_fwd_args a,
const ck_tile::stream_config& s)
{
#if 1
float r = -1;
// clang-format off
// rm rn tm tn vn pd x 3p
if(a.n <= 64) {
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 1, 4, 64, 1, true, true, false>>(s, a);
}
else if(a.n <= 128) {
if (a.n % 2 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 1, 4, 64, 2, true, true, false>>(s, a);
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 2, 4, 64, 1, true, true, false>>(s, a);
}
else if(a.n <= 256) {
if (a.n % 4 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 1, 4, 64, 4, true, true, false>>(s, a);
else if (a.n % 2 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 2, 4, 64, 2, true, true, false>>(s, a);
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 4, 4, 64, 1, true, true, false>>(s, a);
}
else if(a.n <= 512) {
if (a.n % 8 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 1, 4, 64, 8, true, true, false>>(s, a);
else if (a.n % 4 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 2, 4, 64, 4, true, true, false>>(s, a);
else if (a.n % 2 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 4, 4, 64, 2, true, true, false>>(s, a);
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 8, 4, 64, 1, true, true, false>>(s, a);
}
else if(a.n <= 768) {
if (a.n % 4 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 3, 4, 64, 4, true, true, false>>(s, a);
else if (a.n % 2 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 6, 4, 64, 2, true, true, false>>(s, a);
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1,12, 4, 64, 1, true, true, false>>(s, a);
}
else if(a.n <= 1024) {
if (a.n % 8 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 1, 2, 128, 8, true, true, false>>(s, a);
else if (a.n % 4 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 2, 2, 128, 4, true, true, false>>(s, a);
else if (a.n % 2 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 4, 2, 128, 2, true, true, false>>(s, a);
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 4, 1, 256, 1, true, true, false>>(s, a);
}
else if(a.n <= 1536) {
if (a.n % 8 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 3, 4, 64, 8, true, true, false>>(s, a);
else if (a.n % 4 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 3, 2, 128, 4, true, true, false>>(s, a);
else if (a.n % 2 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 3, 1, 256, 2, true, true, false>>(s, a);
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 6, 1, 256, 1, true, true, false>>(s, a);
}
else if(a.n <= 2048) {
if (a.n % 8 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 1, 1, 256, 8, true, true, false>>(s, a);
else if (a.n % 4 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 2, 1, 256, 4, true, true, false>>(s, a);
else if (a.n % 2 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 4, 1, 256, 2, true, true, false>>(s, a);
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 8, 1, 256, 1, true, true, false>>(s, a);
}
else if(a.n <= 3072) {
if (a.n % 8 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 3, 1, 128, 8, true, true, false>>(s, a);
else if (a.n % 4 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 3, 1, 256, 4, true, true, false>>(s, a);
else if (a.n % 2 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 6, 1, 256, 2, true, true, false>>(s, a);
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 3, 1, 1024, 1, true, true, false>>(s, a);
}
else if(a.n <= 4096) {
if (a.n % 8 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 2, 1, 256, 8, true, true, false>>(s, a);
else if (a.n % 4 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 4, 1, 256, 4, true, true, false>>(s, a);
else if (a.n % 2 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 2, 1, 1024, 2, true, true, false>>(s, a);
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 4, 1, 1024, 1, true, true, false>>(s, a);
}
else if(a.n > 4096) {
if (a.n % 8 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 2, 1, 256, 8, true, true, true>>(s, a);
else if (a.n % 4 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 4, 1, 256, 4, true, true, true>>(s, a);
else if (a.n % 2 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 2, 1, 1024, 2, true, true, true>>(s, a);
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 4, 1, 1024, 1, true, true, true>>(s, a);
}
return r;
#else
return add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 1, 2, 128, 8, true, true, false>>(s, a);
#endif
// clang-format on
}
float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits t,
add_rmsnorm2d_rdquant_fwd_args a,
const ck_tile::stream_config& s)
{
float r = -1;
// Only support instance of save_x == true for now
assert(t.save_x);
if(t.data_type.compare("fp16") == 0)
{
return add_rmsnorm2d_rdquant_fwd_b16_<ck_tile::fp16_t>(t, a, s);
}
else if(t.data_type.compare("bf16") == 0)
{
return add_rmsnorm2d_rdquant_fwd_b16_<ck_tile::bf16_t>(t, a, s);
}
if(r < 0)
throw std::runtime_error("Without supported instances!");
return r;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
#if 0
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true , true, false>>(const S&, A);
#endif
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 1, 2, 128, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 2, 2, 128, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 4, 2, 128, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 1, true, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 3, 2, 128, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 1, true, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 8, 1, 256, 1, true, true, false>>(const S&, A);
// clang-format on
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment