"host/vscode:/vscode.git/clone" did not exist on "5ce317cbfff71f1f03c90972ba4fbcb5b7d253d9"
Commit 68b71534 authored by Anthony Chang's avatar Anthony Chang
Browse files

set up example code

parent 89a5e847
......@@ -6,3 +6,4 @@ add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
add_example_executable(example_gemm_gemm_xdl_fp16 gemm_gemm_xdl_fp16.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
|------------|
Gemm0
|---------------------|
Gemm1
*/
#include <iostream>
#include <numeric>
#include <initializer_list>
......@@ -8,7 +16,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
......@@ -29,13 +37,15 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using B0DataType = F16;
using B1DataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using ALayout = Row;
using BLayout = Col;
using B0Layout = Col;
using B1Layout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
......@@ -45,31 +55,45 @@ using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmGemm_Xdl_CShuffle
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
< ALayout,B0Layout, CLayout, ADataType,B0DataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
using ReferenceGemm0Instance = ck::tensor_operation::host::
ReferenceGemm<ADataType, B0DataType, AccDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
using ReferenceGemm1Instance = ck::tensor_operation::host::
ReferenceGemm<AccDataType, B1DataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
// int init_method = 1;
int init_method = 3;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
// ck::index_t M = 1024;
// ck::index_t N = 1024;
// ck::index_t K = 64;
// ck::index_t O = 64;
// ck::index_t StrideA = 1024;
// ck::index_t StrideB0 = 1024;
// ck::index_t StrideB1 = 1024;
// ck::index_t StrideC = 1024;
ck::index_t M = 256;
ck::index_t N = 256;
ck::index_t K = 32;
ck::index_t O = 256;
ck::index_t StrideA = 256;
ck::index_t StrideB0 = 256;
ck::index_t StrideB1 = 256;
ck::index_t StrideC = 256;
if(argc == 1)
{
......@@ -81,7 +105,7 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
else if(argc == 12)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
......@@ -90,10 +114,12 @@ int main(int argc, char* argv[])
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideC = std::stoi(argv[9]);
StrideA = std::stoi(argv[8]);
StrideB0 = std::stoi(argv[9]);
StrideB1 = std::stoi(argv[10]);
StrideC = std::stoi(argv[11]);
}
else
{
......@@ -118,37 +144,44 @@ int main(int argc, char* argv[])
}
};
// C_m_o = A_m_k * B0_k_n * B1_n_o
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB0, B0Layout{}));
Tensor<B1DataType> b1_n_o(f_host_tensor_descriptor(N, O, StrideB1, B1Layout{}));
Tensor<CDataType> c_m_o_host_result(f_host_tensor_descriptor(N, O, StrideC, CLayout{}));
Tensor<CDataType> c_m_o_device_result(f_host_tensor_descriptor(N, O, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
std::cout << "b1_n_o: " << b1_n_o.mDesc << std::endl;
std::cout << "c_m_o: " << c_m_o_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
b1_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b0_k_n.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem b0_k_n_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpace());
DeviceMem b1_n_o_device_buf(sizeof(B1DataType) * b1_n_o.mDesc.GetElementSpace());
DeviceMem c_m_o_device_buf(sizeof(CDataType) * c_m_o_device_result.mDesc.GetElementSpace());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
b0_k_n_device_buf.ToDevice(b0_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
......@@ -158,13 +191,13 @@ int main(int argc, char* argv[])
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_o_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideB0,
StrideC,
a_element_op,
b_element_op,
......@@ -179,9 +212,9 @@ int main(int argc, char* argv[])
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
std::size_t flop = std::size_t(2) * (M * N * K + M * N * O);
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
......@@ -190,19 +223,28 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
c_m_o_device_buf.FromDevice(c_m_o_device_result.mData.data());
if(do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
// Output of Gemm0 is input A of Gemm1
Tensor<AccDataType> a1_m_n(f_host_tensor_descriptor(M, N, N, Row{}));
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a_m_k, b0_k_n, a1_m_n, a_element_op, b_element_op, c_element_op);
ref_gemm0_invoker.Run(ref_gemm0_argument);
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
a1_m_n, b1_n_o, c_m_o_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
ref_gemm1_invoker.Run(ref_gemm1_argument);
return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
return ck::utils::check_err(c_m_o_device_result.mData, c_m_o_host_result.mData) ? 0 : 1;
}
return 0;
......
......@@ -12,7 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_gemm_xdl_cshuffle_v1.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
......@@ -23,6 +23,7 @@ namespace device {
// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle
// version currently has compiler issues with register spill which further causes validation
// failures.
// Computes C = A * B0 * B1
template <typename ALayout,
typename BLayout,
typename CLayout,
......@@ -65,7 +66,7 @@ template <typename ALayout,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
struct DeviceGemmGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
BLayout,
CLayout,
ADataType,
......@@ -75,7 +76,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
BElementwiseOperation,
CElementwiseOperation>
{
using DeviceOp = DeviceGemm_Xdl_CShuffle;
using DeviceOp = DeviceGemmGemm_Xdl_CShuffle;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -350,7 +351,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
using GridwiseGemm = GridwiseGemmGemm_xdl_cshuffle_v1<
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CShuffleDataType,
......@@ -490,7 +491,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_xdl_cshuffle_v1<
const auto kernel = kernel_gemm_gemm_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
......@@ -522,7 +523,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v1<
const auto kernel = kernel_gemm_gemm_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
......
......@@ -32,7 +32,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid,
kernel_gemm_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AElementwiseOperation a_element_op,
......@@ -115,7 +115,7 @@ template <typename FloatAB,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
struct GridwiseGemmGemm_xdl_cshuffle_v1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......
......@@ -151,3 +151,20 @@ struct GeneratorTensor_Sequential
return dims[Dim];
}
};
template <typename T>
struct GeneratorTensor_Diagonal
{
T value{1};
template <typename... Ts>
T operator()(Ts... Xs) const
{
std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
bool pred = true;
for (size_t i = 1; i < dims.size(); i++) {
pred &= (dims[0] == dims[i]);
}
return pred ? value : T{0};
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment