Unverified Commit 1837040a authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

Navi3 rel (#1176)



* wmma_op + unit test

* add arch limitation to wmma test

* change arch limitation

* Refactor + Add all type unit test(int4 compile failed)

* Add f32_16x16x16_bf16 unit test

* tempsave

* tempsave

* tempsave

* runtime bug, cannot find symbol

* workaround for incorrect HIP warpSize return value

* debugging

* tempsave

* Correctness OK, waiting for optimization

* Tidy up + format

* temp save

* temp save, reproduce the v_bfi_b32 issue

* add inline asm for wmmaop test

* tidy up

* clean some debug purpose code

* discard some codes

* clang format

* clang format

* compiler issue fixed + increase tile size

* navi3x_multipleD+example

* temp save

* workable

* batchedgemm[OK], groupconv[debug]

* groupconv: Sanity check[OK], Performance[Bad]

* navi3x_groupconv_need_optimization

* create necessary files

* save progress

* Add Inter-Row thread transfer

* save progress

* save debugging progress

* sanity check pass

* fix a host tensor bug and clean up flash-attn code

* format

* cancel unnecessary change

* cancel unnecessary change

* cancel unnecessary change

* temp save, add asm backend flag to amd_wmma

* Mat-A LDS Bypass sanity pass

* temp save

* gemm sanity fix

* Porting new blockwise gemm to flash attention

* Example branch provide to compiler team

* tempsave

* Fix a bug

* batched gemm ported

* conv A-skip lds ported

* Skip B-Lds real gemm

* Skip B Lds Gemm + MulD

* batched gemm, conv, skip b lds

* format

* Attn, skip b lds

* Change GridwiseOp nam

* fix a typo caused bug

* Skip A_Lds sanity pass, Skip B_Lds scratch occured

* Bug found, intra-row permute off caused

* bug found

* a fix

* disable buffer load due to incorrect 3rd dword

* update fmha config, no scratch generated

* update 3rd dword

* fmha config update

* FMHA, add support to gfx1101/gfx1102

* Merge origin dev (#2)

* [Navi3x] Fix Gridwise_multiple_d operation (#649)

* Add CMake Option "USE_OPT_NAVI3X"

* fix bug

* standardize docs (#655)

* Separate bibtex requirement from rocm-docs-core (#656)

* separate bibtex requirement from rocm-docs-core

* point requirements to source rocm-docs-core repo

* Add CMake Option "USE_OPT_NAVI3X" (#647)

* Add CMake Option "USE_OPT_NAVI3X"

* remove navi3x opt compile option from cmake script

* Conv + quantization + tanh  (#645)

* Rename file. Prepare to support another activation

* Add comment for quantization

* Extract out_elementop

* Add tanh example

* Add conv + bias + tanh quantization instance

* Add missing parameter

* Refine cmake

* Add external api and client example

* Extract variable in example

* Fix the comment

---------
Co-authored-by: default avatarzjing14 <zhangjing14@gmail.com>

* Add a denorm test fix (#603)

* Add type_convert implementations for bf16

* Add the fix for conv_fwd

* Add the fix for conv_bwd_data

* Add the fix for conv_bwd_weight

* Format

* Format

* Another format

* Add a macro to use workaround on MI200 only

* Format

---------
Co-authored-by: default avatarRosty Geyyer <rosty.geyyer@amd.com>
Co-authored-by: default avatarzjing14 <zhangjing14@gmail.com>

* simplify karg in device/grid of split-k op (#644)

* simplify karg in device/grid split-k op

* fix mk_kn_mn instances

* add more instances

* use name from tensor layout

* fix 3rd dword of buffer source descriptor (#659)

* add fp64 instances (#658)
Co-authored-by: default avatarroot <root@ctr-ubbsmc15.amd.com>

* Issue #666: Revert "simplify karg in device/grid of split-k op (#644)" (#665)

This reverts commit bb5530af

.

* Groupnorm + swish external api (#668)

* Rename to proper naming

* Add example of groupnorm + swish

* Extract duplicate code in example

* Add groupnorm + swish instances

* Ractor instance generation, split into multiple cpp file

* Add external api and client example

* Refine profiler message

* Use ck math version of exp

* Refine problem size in example

* Add host version of exp

* add a marco to turn on/off denorm fix (off by default) (#673)

* add a marco to turn off denorm fix by default

* expose the marco

---------
Co-authored-by: default avatarroot <root@ctr-ubbsmc15.amd.com>

* fixed quant example (#672)
Co-authored-by: default avatarroot <root@ctr-ubbsmc15.amd.com>

* Add dependabot config and pin rocm-docs-core (#663)

* [gtest] suppress unsafe buffer warn (#670)

ref: https://github.com/ROCmSoftwarePlatform/MIOpen/pull/1912



* Add memory index guard in wmma device ops (#667)

* Add more macros to turn on/off denorm fix (#678)
Co-authored-by: default avatarRosty Geyyer <rosty.geyyer@amd.com>

* Fix a typo (#676)

* Add (#677)

* Allow using ROCm release candidate compilers. (#679)

* enable use of rocm5.5 release candidate 4

* upgrade to ROCM5.5 RC5

* try fix the PUB_KEY error, remove the cmake-data package

* upgrade to latest cmake version

* use private dockerhub repo for rocm5.5 rc5

* add missing bracket

* add vector load check

* solve conflicts

---------
Co-authored-by: default avatarSam Wu <sjwu@ualberta.ca>
Co-authored-by: default avatarSam Wu <sam.wu2@amd.com>
Co-authored-by: default avatarrocking5566 <ChunYu.Lai@amd.com>
Co-authored-by: default avatarzjing14 <zhangjing14@gmail.com>
Co-authored-by: default avatarRostyslav Geyyer <46627076+geyyer@users.noreply.github.com>
Co-authored-by: default avatarRosty Geyyer <rosty.geyyer@amd.com>
Co-authored-by: default avatarcarlushuang <carlus.huang@amd.com>
Co-authored-by: default avatarroot <root@ctr-ubbsmc15.amd.com>
Co-authored-by: default avatarJun Liu <Liu.Jun@amd.com>
Co-authored-by: default avatarIllia Silin <98187287+illsilin@users.noreply.github.com>

* Disable SkipLDS & Align AIT api (#3)

* fix layernorm, reduction Ops (#4)

* [Navi3x] Fix Gridwise_multiple_d operation (#649)

* Add CMake Option "USE_OPT_NAVI3X"

* fix bug

* standardize docs (#655)

* Separate bibtex requirement from rocm-docs-core (#656)

* separate bibtex requirement from rocm-docs-core

* point requirements to source rocm-docs-core repo

* Add CMake Option "USE_OPT_NAVI3X" (#647)

* Add CMake Option "USE_OPT_NAVI3X"

* remove navi3x opt compile option from cmake script

* Conv + quantization + tanh  (#645)

* Rename file. Prepare to support another activation

* Add comment for quantization

* Extract out_elementop

* Add tanh example

* Add conv + bias + tanh quantization instance

* Add missing parameter

* Refine cmake

* Add external api and client example

* Extract variable in example

* Fix the comment

---------
Co-authored-by: default avatarzjing14 <zhangjing14@gmail.com>

* Add a denorm test fix (#603)

* Add type_convert implementations for bf16

* Add the fix for conv_fwd

* Add the fix for conv_bwd_data

* Add the fix for conv_bwd_weight

* Format

* Format

* Another format

* Add a macro to use workaround on MI200 only

* Format

---------
Co-authored-by: default avatarRosty Geyyer <rosty.geyyer@amd.com>
Co-authored-by: default avatarzjing14 <zhangjing14@gmail.com>

* simplify karg in device/grid of split-k op (#644)

* simplify karg in device/grid split-k op

* fix mk_kn_mn instances

* add more instances

* use name from tensor layout

* fix 3rd dword of buffer source descriptor (#659)

* add fp64 instances (#658)
Co-authored-by: default avatarroot <root@ctr-ubbsmc15.amd.com>

* Issue #666: Revert "simplify karg in device/grid of split-k op (#644)" (#665)

This reverts commit bb5530af

.

* Groupnorm + swish external api (#668)

* Rename to proper naming

* Add example of groupnorm + swish

* Extract duplicate code in example

* Add groupnorm + swish instances

* Ractor instance generation, split into multiple cpp file

* Add external api and client example

* Refine profiler message

* Use ck math version of exp

* Refine problem size in example

* Add host version of exp

* add a marco to turn on/off denorm fix (off by default) (#673)

* add a marco to turn off denorm fix by default

* expose the marco

---------
Co-authored-by: default avatarroot <root@ctr-ubbsmc15.amd.com>

* fixed quant example (#672)
Co-authored-by: default avatarroot <root@ctr-ubbsmc15.amd.com>

* Add dependabot config and pin rocm-docs-core (#663)

* [gtest] suppress unsafe buffer warn (#670)

ref: https://github.com/ROCmSoftwarePlatform/MIOpen/pull/1912



* Add memory index guard in wmma device ops (#667)

* Add more macros to turn on/off denorm fix (#678)
Co-authored-by: default avatarRosty Geyyer <rosty.geyyer@amd.com>

* Fix a typo (#676)

* Add (#677)

* Allow using ROCm release candidate compilers. (#679)

* enable use of rocm5.5 release candidate 4

* upgrade to ROCM5.5 RC5

* try fix the PUB_KEY error, remove the cmake-data package

* upgrade to latest cmake version

* use private dockerhub repo for rocm5.5 rc5

* add missing bracket

* Disable SkipLDS & Align AIT api

* Update dependabot config (#682)
Co-authored-by: default avatarsamjwu <samjwu@users.noreply.github.com>

* update attn api

* solve type_convert bug + enable

---------
Co-authored-by: default avatarSam Wu <sjwu@ualberta.ca>
Co-authored-by: default avatarSam Wu <sam.wu2@amd.com>
Co-authored-by: default avatarrocking5566 <ChunYu.Lai@amd.com>
Co-authored-by: default avatarzjing14 <zhangjing14@gmail.com>
Co-authored-by: default avatarRostyslav Geyyer <46627076+geyyer@users.noreply.github.com>
Co-authored-by: default avatarRosty Geyyer <rosty.geyyer@amd.com>
Co-authored-by: default avatarcarlushuang <carlus.huang@amd.com>
Co-authored-by: default avatarroot <root@ctr-ubbsmc15.amd.com>
Co-authored-by: default avatarJun Liu <Liu.Jun@amd.com>
Co-authored-by: default avatarIllia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: default avatarsamjwu <samjwu@users.noreply.github.com>
Co-authored-by: default avatarhaocwang <Haocong.WANG@amd.com>

* fix typo

* Fix attention with causal mask

* multiple fix, try ait compile

* Add A/B not use LDS pipeline

* Clang format, Add gfx1101, gfx1102 support of FMHA example

* cancel change of format script

* 1. Enable 2-stage global Prefetch ( May cause VGPR spilling)
2. Enable FP16 accumulator blockwise_gemm

* clang-format

* 1. change blockwise gemm loopover direction from kmn to mnk ( ~1% improvement)
2. change kernel timing mode to 50 warmup + 50 timed repeat

* Update low level abstration of blockwise gemm wmma

* (2/5) bilinear gemm pass, perf bug: skip a lds has lower performance than skip b lds

* (3/5) batched gemm pass, perf bug: skip a lds has lower performance than skip b lds

* (4/5) grouped conv pass

* (5/5) attention pass, todo: debug lds perf bug

* AIT Attention API refactor (#8)

* sanity pass

* sanity pass 2

* confirm significant performance regression.

* turn on all instances

* turn off instance format

* Fix bug & tunning & format

* DML meta, self_attn+cross_attn

* sanity pass

* remove useless flag

* update tile and problem size used in AIT attention

* bug fix in grouped conv supporting check

* deprecate inline asm wmma

* Bug fix: double lds skip

* clang-format

* Fix errors in
1. example, fmha
2. gridwise pipeline
3. deviceop, fmha, change some containers from vector to array

* part2 of previous commit

* clang format

* API fix of gridwisegemmpipeline

* separate array base and vector base attention tensor transformation

* fix gemm

* clang format

* add gemm fp16 instances

* Temp save

* fpAintB kernel compile pass

* Sanity pass.

* Temp save

* debug code enabled

* Fp16AInt8B_GEMM sanity

* MQA implementation

* GQA-4 example

* tempsave

* Compile pass

* New implementation of fp16Aint8B Gemm, Acheieve similar math throughput with native fp16 Gemm

* format

* Todo: fix gemm_bilinear_wmma instances compilation bug

* Solve a bug when K1=16

* remove unnecessary changes

* Remove tensor layout limitation to LDS usage in tesnor contraction

* update self-attention and cross-attention

* fix a typo of name

* Add arch limiter for fp8 gemm

* enable fp8 gemm_xdl for all gfx9 targets

* temporarily disable gemm_xdl_fp16_fp8 on MI100/200

* fix the cmake logic for gemm_xdl_fp16_fp8

* re-enable the gemm_xdl_fp16_fp8 on MI100/200

---------
Co-authored-by: default avataraska-0096 <haocwang@amd.com>
Co-authored-by: default avatarSam Wu <sjwu@ualberta.ca>
Co-authored-by: default avatarSam Wu <sam.wu2@amd.com>
Co-authored-by: default avatarrocking5566 <ChunYu.Lai@amd.com>
Co-authored-by: default avatarRostyslav Geyyer <46627076+geyyer@users.noreply.github.com>
Co-authored-by: default avatarRosty Geyyer <rosty.geyyer@amd.com>
Co-authored-by: default avatarcarlushuang <carlus.huang@amd.com>
Co-authored-by: default avatarroot <root@ctr-ubbsmc15.amd.com>
Co-authored-by: default avatarJun Liu <Liu.Jun@amd.com>
Co-authored-by: default avatarIllia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: default avatarsamjwu <samjwu@users.noreply.github.com>
Co-authored-by: default avatarhaocwang <Haocong.WANG@amd.com>
Co-authored-by: default avatarillsilin <Illia.Silin@amd.com>
parent 363feb48
if(GPU_TARGETS MATCHES "gfx11")
add_custom_target(example_fpAintB_gemm_wmma)
add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp)
add_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma)
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <numeric>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp"
struct ProblemSize final
{
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
};
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
template <typename IntType>
struct UnsignedWeightPreprocessor
{
};
template <>
struct UnsignedWeightPreprocessor<int8_t>
{
using UnsignedWeight = Tensor<uint8_t>;
using SignedWeight = Tensor<int8_t>;
static UnsignedWeight convert(SignedWeight const& Input)
{
UnsignedWeight Output = Input.template CopyAsType<uint8_t>();
auto f_kn = [&](auto k, auto n) {
const uint8_t adder = 128;
int8_t v_signed_weight;
uint8_t v_unsigned_weight;
ck::tensor_operation::element_wise::PassThrough{}(v_signed_weight, Input(k, n));
v_unsigned_weight = ck::type_convert<uint8_t>(v_signed_weight) + adder;
Output(k, n) = v_unsigned_weight;
};
make_ParallelTensorFunctor(f_kn, Input.mDesc.GetLengths()[0], Input.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
return Output;
}
UnsignedWeight operator()(SignedWeight const& Input) { return convert(Input); }
};
inline bool
parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
{
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.M = std::stoi(argv[4]);
problem_size.N = std::stoi(argv[5]);
problem_size.K = std::stoi(argv[6]);
problem_size.StrideA = std::stoi(argv[7]);
problem_size.StrideB = std::stoi(argv[8]);
problem_size.StrideC = std::stoi(argv[9]);
}
else
{
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<< std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl;
return false;
}
return true;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp"
// Implementation follows the paper:
// Kim, Young Jin, Rawn Henry, Raffy Fahim, and Hany Hassan Awadalla. “Who Says Elephants Can’t Run:
// Bringing Large Scale MoE Models into Cloud Scale Production.” arXiv, November 17, 2022.
// https://doi.org/10.48550/arXiv.2211.10017. Assume weight (Matrix B) is add preprocess to
// unsigned.
// The DeviceOp is CDataType = ADataType * Dequant(BDataType) * ScaleDataType
// The HostRef is CDataType = ADataType * Dequant(QuantDataType) * ScaleDataType
// TODO: Current implementation consume more VGPR than expected.
using ADataType = ck::half_t;
using QuantDataType = int8_t;
using BDataType = uint8_t;
using ScaleDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = float;
using CDataType = ck::half_t;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceFpAintBGemm_Wmma_CShuffle
< ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
ScaleDataType,
CDataType,
AccDataType,
CShuffleDataType,
AElementOp,
BElementOp,
CElementOp,
GemmDefault,
1, // Prefetch stage
128, // BlockSize
64, // MPerBlock
128, // NPerBlock
64, // KPerBlock
8, // K1
16, // MPerWmma
16, // NPerWmma
2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1, // C shuffle (M Repeat) Per store
1, // C shuffle (N Repeat) Per store
S<1, 32, 1, 4>,
8>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferencefpAintBGemm<ADataType,
QuantDataType,
ScaleDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;
#include "run_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<QuantDataType> quant_b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
// assume scale tensor is [1, n]
Tensor<ScaleDataType> scale_k_n(f_host_tensor_descriptor(K, N, 0, Row{}));
switch(config.init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<QuantDataType>{-1.f, 1.f}(quant_b_k_n);
ck::utils::FillUniformDistributionIntegerValue<ScaleDataType>{-1.f, 1.f}(scale_k_n);
break;
case 2:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<QuantDataType>{-1.f, 1.f}(quant_b_k_n);
ck::utils::FillUniformDistribution<ScaleDataType>{-1.f, 1.f}(scale_k_n);
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<QuantDataType>{-1.f, 1.f}(quant_b_k_n);
ck::utils::FillUniformDistribution<ScaleDataType>{-1.f, 1.f}(scale_k_n);
}
UnsignedWeightPreprocessor<QuantDataType> preprocessor;
Tensor<BDataType> b_k_n = preprocessor(quant_b_k_n);
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "scale_k_n: " << scale_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
#ifdef BUILD_INT4_EXAMPLE
DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) *
c_m_n_device_result.mDesc.GetElementSpaceSize());
const Tensor<KernelADataType> a_m_k_converted(a_m_k);
const Tensor<KernelBDataType> b_k_n_converted(b_k_n);
a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data());
b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data());
#else
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem scale_k_n_device_buf(sizeof(ScaleDataType) * scale_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
scale_k_n_device_buf.ToDevice(scale_k_n.mData.data());
#endif
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(
#ifdef BUILD_INT4_EXAMPLE
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#else
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<ScaleDataType*>(scale_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#endif
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
if(config.do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_m_k,
quant_b_k_n,
scale_k_n,
c_m_n_host_result,
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
#ifdef BUILD_INT4_EXAMPLE
Tensor<CDataType> c_m_n_device_result_converted(c_m_n_host_result.mDesc);
c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data());
c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>();
return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
#else
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
#endif
}
return true;
}
bool run_gemm_example(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp"
namespace ck {
/**
* @brief Blockwise data transfer with dequantization
*
* RunRead would load low-precision data and scale data.
* RunWrite would process dequantization process.
* Assume Scale is identical along K-dimension
*
* This version does following things to avoid scratch memory issue
* 1. Use StaticallyIndexedArray instead of C array for thread buffer
* 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
* 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
*
*/
template <typename ThreadGroup,
typename SrcElementwiseOperation,
typename ScaleElementwiseOperation,
typename DstElementwiseOperation,
InMemoryDataOperationEnum DstInMemOp,
typename BlockSliceLengths,
typename BlockScaleSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,
typename ScaleData,
typename DstData,
typename SrcDesc,
typename ScaleDesc,
typename DstDesc,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t ScaleScalarPerVector,
index_t DstScalarPerVector,
index_t SrcScalarStrideInVector,
index_t ScaleScalarStrideInVector,
index_t DstScalarStrideInVector,
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun,
index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v4r1_dequant
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
static constexpr auto scale_thread_slice_lengths =
BlockScaleSliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>;
__device__ constexpr ThreadGroupTensorSliceTransfer_v4r1_dequant(
const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const SrcElementwiseOperation& src_element_op,
const ScaleDesc& scale_desc,
const Index& scale_block_slice_origin,
const ScaleElementwiseOperation& scale_element_op,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const DstElementwiseOperation& dst_element_op)
: threadwise_transfer_(src_desc,
make_zero_multi_index<nDim>(),
src_element_op,
scale_desc,
make_zero_multi_index<nDim>(),
scale_element_op,
dst_desc,
make_zero_multi_index<nDim>(),
dst_element_op)
{
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
nDim == remove_cvref_t<ScaleDesc>::GetNumOfDimension() &&
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{} &&
is_same<BlockScaleSliceLengths,
decltype(scale_thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
"wrong! ThreadGroup::GetNumOfThread() too small");
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetScaleSliceOrigin(
scale_desc, scale_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
dst_block_slice_origin + thread_data_idx_begin);
}
}
template <typename SrcBuffer, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id);
}
}
// With the assumption, scale scratch is always one
template <typename ScaleBuffer>
__device__ void RunScaleRead(const ScaleDesc& scale_desc, const ScaleBuffer& scale_buf)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunScaleRead(scale_desc, scale_buf);
}
}
template <typename DstBuffer, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDesc& dst_desc,
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id);
}
}
// We don't prefer use this API directly
/*
template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id)
{
RunRead(src_desc, src_buf, thread_scratch_id);
RunWrite(dst_desc, dst_buf, thread_scratch_id);
}
*/
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
}
}
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
}
// With the assumption, scale buffer don't need move slice window method
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v3r1_dequant<decltype(thread_slice_lengths),
decltype(scale_thread_slice_lengths),
SrcElementwiseOperation,
ScaleElementwiseOperation,
DstElementwiseOperation,
DstInMemOp,
SrcData,
ScaleData,
DstData,
SrcDesc,
ScaleDesc,
DstDesc,
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim,
DstVectorDim,
SrcScalarPerVector,
ScaleScalarPerVector,
DstScalarPerVector,
SrcScalarStrideInVector,
ScaleScalarStrideInVector,
DstScalarStrideInVector,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun,
NumThreadScratch>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Dequantization of input tensor could not be decoupled from gridwisegemm pipeline
// As input tensor thread buffer declared inside blockwise-gemm pipeline.
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemm_dequantB : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const void* p_scale,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -196,7 +196,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
using EGridDesc_M_N = remove_cvref_t<tuple_element_t<3, ABDsEGridDesc>>;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
using GridwiseGemm = GridwiseGemmMultipleD_Wmma<
// DataType Family
ADataType,
BDataType,
......@@ -217,7 +217,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
// Tiling Family
MPerBlock,
NPerBlock,
K0PerBlock,
KPerBlock,
MPerWMMA,
NPerWMMA,
K1,
......@@ -232,6 +232,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
true,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
......@@ -240,6 +241,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
true,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
......
......@@ -393,12 +393,14 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
using CShuffleDataType = AccDataType;
using GridwiseGemm = GridwiseGemmMultipleD_Wmma<
// DataType Family
ADataType,
BDataType,
AccDataType,
CDataType,
CShuffleDataType,
Tuple<>,
CDataType,
// InMemory Data Descriptor
......@@ -414,7 +416,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
// Tiling Family
MPerBlock,
NPerBlock,
K0PerBlock,
KPerBlock,
MPerWMMA,
NPerWMMA,
K1,
......@@ -429,6 +431,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false,
true,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
......@@ -437,6 +440,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false,
true,
BBlockLdsAddExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
......
......@@ -53,7 +53,10 @@ struct MaskOutUpperTrianglePredicate
template <typename MaskOutPredicate>
struct C0MatrixMask_impl
{
C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {}
__host__ __device__ C0MatrixMask_impl(index_t NRaw)
: NRaw_(NRaw), predicate_(MaskOutPredicate{})
{
}
__host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const
{
......
......@@ -123,6 +123,12 @@ struct PassThrough
y = type_convert<bhalf_t>(x);
}
template <>
__host__ __device__ void operator()<uint8_t, uint8_t>(uint8_t& y, const uint8_t& x) const
{
y = x;
}
template <>
__host__ __device__ void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
{
......@@ -663,6 +669,76 @@ struct Elu
const float alpha_;
};
// support fastconvert of int8 to fp16
template <typename InputDataType, typename OutputDataType, index_t RegPackNumber>
struct FastNumericArrayConverter
{
};
template <>
struct FastNumericArrayConverter<uint8_t, ck::half_t, 4>
{
using InputArray = vector_type<uint8_t, 4>;
using OutputArray = vector_type<ck::half_t, 4>;
__device__ static OutputArray convert(InputArray const& Input)
{
OutputArray Output;
uint32_t* half_2 = reinterpret_cast<uint32_t*>(&Output);
uint32_t const uint8_4 = reinterpret_cast<uint32_t const&>(Input);
static constexpr uint32_t byte_selector_01 = 0x05010500;
static constexpr uint32_t byte_selector_23 = 0x05030502;
static constexpr uint32_t fp16_adder = 0x64646464;
half_2[0] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_01);
half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
: "=v"(half_2[0])
: "v"(half_2[0]), "s"(I8s_TO_F16s_MAGIC_NUM));
asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
: "=v"(half_2[1])
: "v"(half_2[1]), "s"(I8s_TO_F16s_MAGIC_NUM));
return Output;
}
__device__ OutputArray operator()(InputArray const& Input) { return convert(Input); }
};
template <index_t N>
struct FastNumericArrayConverter<uint8_t, ck::half_t, N>
{
static constexpr int VEC_WIDTH = 4;
static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");
using InputArray = vector_type<uint8_t, N>;
using OutputArray = vector_type<ck::half_t, N>;
__device__ static OutputArray convert(InputArray const& Input)
{
FastNumericArrayConverter<uint8_t, ck::half_t, 4> converter;
OutputArray Output;
using Vec_InputArray = vector_type<uint8_t, 4>;
using Vec_OutputArray = vector_type<ck::half_t, 4>;
Vec_OutputArray* half_4_ptr = reinterpret_cast<Vec_OutputArray*>(&Output);
Vec_InputArray const* uint8_4_ptr = reinterpret_cast<Vec_InputArray const*>(&Input);
static_for<0, N / VEC_WIDTH, 1>{}(
[&](auto i) { half_4_ptr[i] = converter(uint8_4_ptr[i]); });
return Output;
}
__device__ OutputArray operator()(InputArray const& Input) { return convert(Input); }
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
......@@ -116,7 +116,7 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemm0KPrefetchStage>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemm0KPrefetchStage, true, true>;
// ck::Tuple<const D0DataType1*, const D0DataType2*, ...>
static constexpr auto MakeD0sGridPointer()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment