Unverified Commit b8d11559 authored by amd-khushbu, committed by GitHub

Merge branch 'develop' into ck_profiler_m_instances

parents 7f3fe4e7 3b230208
@@ -92,6 +92,7 @@ endif()
 add_compile_options(-Wno-bit-int-extension)
 add_compile_options(-Wno-pass-failed)
 add_compile_options(-Wno-switch-default)
+add_compile_options(-Wno-unique-object-duplication)
 if(DL_KERNELS)
   add_definitions(-DDL_KERNELS)
...
@@ -117,7 +117,7 @@ def getDockerImage(Map conf=[:]){
     {
         echo "Pulling down image: ${image}"
         retimage = docker.image("${image}")
-        withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) {
+        withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) {
            retimage.pull()
         }
     }
@@ -148,7 +148,7 @@ def buildDocker(install_prefix){
         //force building the new docker if that parameter is true
         echo "Building image: ${image_name}"
         retimage = docker.build("${image_name}", dockerArgs)
-        withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) {
+        withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) {
            retimage.push()
         }
         sh 'docker images -q -f dangling=true | xargs --no-run-if-empty docker rmi'
@@ -162,7 +162,7 @@ def buildDocker(install_prefix){
     catch(Exception ex){
         echo "Unable to locate image: ${image_name}. Building image now"
         retimage = docker.build("${image_name}", dockerArgs + ' .')
-        withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) {
+        withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) {
            retimage.push()
         }
     }
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
#include "ck/host/operation/gemm.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {
// defines all values needed for an instance of batched GEMM + softmax + GEMM
struct Operation_Xdl_CShuffle
{
// returns a vector of instances given only fusion operators; uses the default problem spec
static std::vector<std::vector<Operation_Xdl_CShuffle>>
CreateOperations(const std::string& prologue, const std::string& epilogue);
// returns a vector of instances, given a problem spec and fusion operators
static std::vector<Operation_Xdl_CShuffle>
CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue);
TensorDesc A{};
TensorDesc B{};
TensorDesc B1{};
TensorDesc C{};
DataType acc = DataType::Float;
DataType cs_type = DataType::Half;
std::string a_elem_op = PassThrough;
std::string b_elem_op = PassThrough;
std::string b1_elem_op = PassThrough;
std::string c_elem_op = PassThrough;
std::string acc_elem_op = Scale;
std::string prologue = "";
std::string epilogue = "";
std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default";
// tuning parameters
operation::TileDescGemmGemm tile_desc{};
operation::BlockTransferDesc a_block_transfer{};
operation::BlockTransferDesc b0_block_transfer{};
operation::BlockTransferDesc b1_block_transfer{};
operation::CShuffleDesc cshuffle{};
operation::CBlockTransferDesc c_block_transfer{};
bool mask_out_upper_triangle = false;
// functions to update fusion operators if provided
void update_prologue(const std::string& prologue);
void update_epilogue(const std::string& epilogue);
/**constexpr**/ bool
IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_, std::size_t Gemm1NRaw_);
// returns a templated instance
Solution ToSolution() const;
};
} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
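
A minimal usage sketch of this new header (not part of the commit; the problem sizes and the wrapper function are illustrative assumptions): enumerate the default instances, keep the ones that report support for a given shape, and render each survivor into a templated instance.

// Hypothetical usage sketch -- sizes below are illustrative, not taken from this commit.
#include <vector>
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"

std::vector<ck::host::Solution> enumerate_supported_instances()
{
    using ck::host::device_batched_gemm_softmax_gemm::Operation_Xdl_CShuffle;
    // Default problem spec, no fusion prologue/epilogue.
    auto groups = Operation_Xdl_CShuffle::CreateOperations(/*prologue=*/"", /*epilogue=*/"");
    std::vector<ck::host::Solution> out;
    for(const auto& group : groups)
        for(auto op : group) // copy: IsSupported() is not declared const
            if(op.IsSupported(/*MRaw_=*/1024, /*NRaw_=*/1024, /*KRaw_=*/64, /*Gemm1NRaw_=*/64))
                out.push_back(op.ToSolution());
    return out;
}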
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {
// defines the problem specification for a batched GEMM + softmax + GEMM operation
struct Problem
{
std::size_t M = 0;
std::size_t N = 0;
std::size_t K = 0;
std::size_t O = 0;
bool TransA = false;
bool TransB = false;
bool TransB1 = false;
bool TransC = false;
DataType ADataType = DataType::Half;
DataType BDataType = DataType::Half;
DataType B1DataType = DataType::Half;
DataType CDataType = DataType::Half;
std::string AElementOp = PassThrough;
std::string BElementOp = PassThrough;
std::string B1ElementOp = PassThrough;
std::string CElementOp = PassThrough;
std::string AccElementOp = Scale;
// returns the correct device op file for the operation
std::string GetIncludeHeader() const;
// returns a list of instances based on the problem spec and provided fusion operations
std::vector<Solution> GetSolutions(const std::string& arch,
const std::string& prologue,
const std::string& epilogue) const;
};
} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
@@ -41,6 +41,8 @@ struct Operation_Xdl_CShuffle
     operation::BlockTransferDesc b_block_transfer{};
     operation::CShuffleDesc cshuffle{};
     operation::CBlockTransferDesc c_block_transfer{};
+    LoopScheduler loop_scheduler{};
+    PipelineVersion pipeline_version{};
     // functions to update fusion operators if provided
     void update_prologue(const std::string& prologue);
...
@@ -23,6 +23,26 @@ struct TileDesc
     int n_Xdl_per_wave = 0;
     int num_gemmk_prefetch_stage = 0;
 };
+
+struct TileDescGemmGemm
+{
+    int block_size = 0;
+    int gemm01_m_per_block = 0;
+    int gemm0_n_per_block = 0;
+    int gemm0_k_per_block = 0;
+    int gemm1_n_per_block = 0;
+    int gemm1_k_per_block = 0;
+    int ak1 = 0;
+    int bk1 = 0;
+    int b1k1 = 0;
+    int m_per_XDL = 0;
+    int n_per_XDL = 0;
+    int gemm0_m_Xdl_per_wave = 0;
+    int gemm0_n_Xdl_per_wave = 0;
+    int gemm1_n_Xdl_per_wave = 0;
+    int num_gemmk_prefetch_stage = 0;
+};
+
 struct BlockTransferDesc
 {
     std::string thread_cluster_length = "";
...
@@ -66,6 +66,20 @@ enum class GemmType
 };
 std::string ToString(GemmType gt);
+
+enum class LoopScheduler
+{
+    Default,
+    Interwave,
+};
+std::string ToString(LoopScheduler ls);
+
+enum class PipelineVersion
+{
+    v1,
+    v2
+};
+std::string ToString(PipelineVersion pv);
+
 struct TensorDesc
 {
     DataType element;
@@ -84,6 +98,7 @@ const std::string S = SequenceStr({xs...});
 constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough";
 constexpr const char* Bilinear = "ck::tensor_operation::element_wise::Bilinear";
+constexpr const char* Scale = "ck::tensor_operation::element_wise::Scale";
 } // namespace host
 } // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
#include "ck/host/utils.hpp"
#include <algorithm>
namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {
// return the relevant device op file based on the operation
std::string Problem::GetIncludeHeader() const
{
return "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp";
}
// returns templated instances when provided with a problem specification
std::vector<Solution> Problem::GetSolutions(const std::string& arch,
const std::string& prologue,
const std::string& epilogue) const
{
if(get_xdlop_archs().count(arch) == 0)
return {};
auto ops = ck::host::device_batched_gemm_softmax_gemm::Operation_Xdl_CShuffle::CreateOperations(
*this, prologue, epilogue); // obtains vector of instances
std::vector<Solution> result;
std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) {
return op.ToSolution(); // template instance with correct values
});
return result;
}
} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
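
A minimal sketch of the Problem-driven path (not part of the commit; "gfx90a" being in get_xdlop_archs() and the field values are assumptions for illustration):

// Hypothetical usage sketch, not part of this commit.
#include <vector>
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"

std::vector<ck::host::Solution> solutions_for_gfx90a()
{
    ck::host::device_batched_gemm_softmax_gemm::Problem prob;
    prob.M = 1024; // gemm0 M
    prob.N = 1024; // gemm0 N
    prob.K = 64;   // gemm0 K
    prob.O = 64;   // gemm1 N
    // Returns an empty vector for architectures without XDL support.
    return prob.GetSolutions("gfx90a", /*prologue=*/"", /*epilogue=*/"");
}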
@@ -62,6 +62,12 @@ void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
 // accounts for all possible combinations of Row/Col major
 static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; }
+
+// clang-format off
+// DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1,
+// DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
+// clang-format on
+
 // Hard-code tuning parameters in modularized fashion, string them together into a vector of
 // instances
 std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
@@ -83,6 +89,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
         { 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, 1},
         { 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, 1},
         { 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 1},
+        // Irregular tile
+        { 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, 1},
         // clang-format on
     };
@@ -100,6 +108,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
         { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
         { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
         { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
+        // Irregular tile
+        { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1},
         // clang-format on
     };
@@ -109,15 +119,17 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
         // ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
         // Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
         // | | | | | | |
-        {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
-        {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
-        {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
-        {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
-        {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
-        {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
-        {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
-        {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
+        { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
+        { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
+        { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
+        { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
+        { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
+        { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
+        { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
+        { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
+        // Irregular tile
+        { S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
         // clang-format on
     };

     std::vector<operation::BlockTransferDesc> b_block_descriptions_rowmajor = {
@@ -134,6 +146,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
         { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
         { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
         { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
+        // Irregular tile
+        { S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
         // clang-format on
     };
@@ -151,6 +165,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
         { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
         { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
         { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
+        // Irregular tile
+        { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1},
         // clang-format on
     };
@@ -167,6 +183,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
         { 1, 1},
         { 1, 1},
         { 1, 1},
+        { 1, 1},
         { 1, 1},
         // clang-format on
     };
@@ -185,6 +202,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
         { S<1, 16, 1, 8>, 8},
         { S<1, 32, 1, 8>, 8},
         { S<1, 32, 1, 8>, 8},
+        // Irregular tile
+        { S<1, 16, 1, 4>, 1},
         // clang-format on
     };
@@ -199,6 +218,14 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
     assert(tile_descriptions.size() == cshuffle_descriptions.size());
     assert(tile_descriptions.size() == c_block_descriptions.size());
+
+    const std::vector<std::tuple<LoopScheduler, PipelineVersion>> scheduler_pipeline_descriptions =
+        {
+            {LoopScheduler::Default, PipelineVersion::v1},
+            {LoopScheduler::Interwave, PipelineVersion::v1},
+            {LoopScheduler::Default, PipelineVersion::v2},
+        };
+
+    for(auto [loop_scheduler, pipeline_version] : scheduler_pipeline_descriptions)
+    {
     // Put all values together into a single operation > store into the result vector
     for(std::size_t i = 0; i < tile_descriptions.size(); i++)
     {
@@ -223,10 +250,13 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
                                           x.tile_desc.m_per_block,
                                           x.tile_desc.n_per_block,
                                           x.tile_desc.k_per_block);
+        x.loop_scheduler = loop_scheduler;
+        x.pipeline_version = pipeline_version;
         x.update_prologue(prologue);
         x.update_epilogue(epilogue);
         result.push_back(x);
     }
+    }
     return result;
 }
@@ -263,7 +293,7 @@ static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate =
     "${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, "
     "${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, "
     "${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, "
-    "${CDEBlockTransferScalarPerVector_NPerBlock}>";
+    "${CDEBlockTransferScalarPerVector_NPerBlock}, ${LoopScheduler}, ${PipelineVersion}>";

 // use hardcoded instances from vector of operations to substitute values into instance template
 Solution Operation_Xdl_CShuffle::ToSolution() const
@@ -336,6 +366,8 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
          this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl},
         {"CDEBlockTransferScalarPerVector_NPerBlock",
          std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)},
+        {"LoopScheduler", ToString(this->loop_scheduler)},
+        {"PipelineVersion", ToString(this->pipeline_version)},
     };
     return Solution{InterpolateString(DeviceGemmMultipleD_Xdl_CShuffleTemplate, values),
...
@@ -59,6 +59,26 @@ std::string ToString(GemmType gt)
     throw std::runtime_error("Incorrect gemm type");
 }
+
+std::string ToString(LoopScheduler ls)
+{
+    switch(ls)
+    {
+    case LoopScheduler::Default: return "ck::LoopScheduler::Default";
+    case LoopScheduler::Interwave: return "ck::LoopScheduler::Interwave";
+    }
+    throw std::runtime_error("Incorrect LoopScheduler type");
+}
+
+std::string ToString(PipelineVersion pv)
+{
+    switch(pv)
+    {
+    case PipelineVersion::v1: return "ck::PipelineVersion::v1";
+    case PipelineVersion::v2: return "ck::PipelineVersion::v2";
+    }
+    throw std::runtime_error("Incorrect PipelineVersion type");
+}
+
 std::string SequenceStr(const std::vector<int>& v)
 {
     return "ck::Sequence<" +
...
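
A quick check of the new enum-to-string mapping (a self-contained sketch that assumes only the ToString overloads added above):

#include <cassert>
#include "ck/host/types.hpp"

int main()
{
    using namespace ck::host;
    // The returned strings are spliced directly into the instance template.
    assert(ToString(LoopScheduler::Interwave) == "ck::LoopScheduler::Interwave");
    assert(ToString(PipelineVersion::v2) == "ck::PipelineVersion::v2");
    return 0;
}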
@@ -8,6 +8,7 @@
 #include <memory>
 #include <stdexcept>
 #include <string>
+#include <stdexcept>

 namespace rtc {
...
@@ -27,11 +27,15 @@ using DeviceGemmStreamK = ck::tensor_operation::device::DeviceGemmXdlStreamK
 // ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
 // ######| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
 // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if defined(CK_USE_AMD_MFMA_GFX950)
+    < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
+#else // defined(CK_USE_AMD_MFMA_GFX950)
     < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
     // < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8>;
     // < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>;
     // < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 128, 32, 128, 4, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8>;
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
...
@@ -506,6 +506,14 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
                 cond &= deterministic == "f"
                 if not cond:
                     continue
+            if receipt == 4:
+                cond = dtype in ['fp16', 'bf16']
+                cond &= bias in ['no', 'bias']
+                cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
+                cond &= dpad == dvpad
+                cond &= deterministic == "f"
+                if not cond:
+                    continue
             api_pool.register_dq_dk_dv_traits(k.api_trait())
             gen.append(k)
...
@@ -487,13 +487,20 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
             if kernel_filter != None:
                 if not fnmatch.fnmatch(k.name, kernel_filter):
                     continue
-            if receipt == 2:
+            if receipt in (2, 3):
                 cond = dtype in ['fp16', 'bf16']
                 cond &= pipeline.F_vlayout == 'row'
                 cond &= pipeline.F_bias in ['no', 'alibi']
                 cond &= pipeline.F_squant == 'f'
                 if not cond:
                     continue
+            if receipt == 4:
+                cond = dtype in ['fp16', 'bf16']
+                cond &= pipeline.F_vlayout == 'row'
+                cond &= pipeline.F_bias in ['no', 'bias']
+                cond &= pipeline.F_squant == 'f'
+                if not cond:
+                    continue
             api_pool.register_traits(k.api_trait())
             gen.append(k)
...
@@ -103,7 +103,8 @@ if __name__ == "__main__":
         required=False,
         help="codegen receipt. 0: generate only 8xhdim coverage\n" + \
              " 1: generate more instance to cover all hdim\n" + \
-             " 2: Only generate instance for Flash attention integration"
+             " 2: Only generate instance for Flash attention integration\n" + \
+             " 4: Only generate instance for PyTorch integration"
     )
     args = parser.parse_args()
...
 add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
 add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp)
+target_compile_options(tile_example_gemm_universal PRIVATE
+    -mllvm -enable-noalias-to-md-conversion=0
+)
@@ -82,8 +82,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
     if(s.log_level_ > 0)
     {
-        std::cout << "Launching kernel with args:"
-                  << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
+        std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
+                  << "shape: " << CodegenGemmShape::GetName() << '\n'
+                  << "problem: " << CodegenPipelineProblem::GetName() << '\n'
+                  << "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
+                  << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                   << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                   << std::endl;
     }
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
@@ -11,21 +11,26 @@
 #include "ck_tile/ops/epilogue.hpp"
 #include "ck_tile/ops/gemm.hpp"

-#define CK_TILE_PIPELINE_COMPUTE 1
+#define CK_TILE_PIPELINE_COMPUTE_V3 1
 #define CK_TILE_PIPELINE_MEMORY 2
+#define CK_TILE_PIPELINE_COMPUTE_V4 3

 #ifndef CK_TILE_PIPELINE_DEFAULT
-#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
+#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
 #endif

-#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
+#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
 #define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
 #define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
 #define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
-#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
+#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
 #define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
 #define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
 #define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
+#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
+#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4
+#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4
+#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
 #else
 #error "unsupported CK_TILE_PIPELINE_DEFAULT value"
 #endif
@@ -126,7 +131,8 @@ auto create_args(int argc, char* argv[])
         .insert("warmup", "50", "number of iterations before benchmark the kernel")
         .insert("repeat", "100", "number of iterations to benchmark the kernel")
         .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
-        .insert("split_k", "1", "splitK value");
+        .insert("split_k", "1", "splitK value")
+        .insert("init", "0", "0:random, 1:linear, 2:constant(1)");

     bool result = arg_parser.parse(argc, argv);
     return std::make_tuple(result, arg_parser);
...
@@ -30,8 +30,13 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
     return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
 }

-template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType,
-          typename ALayout, typename BLayout, typename CLayout>
+template <typename ADataType,
+          typename BDataType,
+          typename AccDataType,
+          typename CDataType,
+          typename ALayout,
+          typename BLayout,
+          typename CLayout>
 float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                   ck_tile::DeviceMem& b_k_n_dev_buf,
                   ck_tile::DeviceMem& c_m_n_dev_buf,
@@ -57,8 +62,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
     args.stride_B = stride_B;
     args.stride_C = stride_C;

-    float ave_time = gemm_calc<ADataType, BDataType, AccDataType, CDataType,
-                               ALayout, BLayout, CLayout>(
+    float ave_time =
+        gemm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
             args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});

     std::size_t flop = std::size_t(2) * M * N * K;
@@ -69,14 +74,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
     std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
               << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
-              << " A_Layout =" << ALayout::name
-              << " B_Layout =" << BLayout::name
-              << " C_Layout =" << CLayout::name
-              << " A Type = " << DataTypeTraits<ADataType>::name
+              << " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name
+              << " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits<ADataType>::name
               << " B Type = " << DataTypeTraits<BDataType>::name
-              << " C Type = " << DataTypeTraits<CDataType>::name
-              << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
-              << std::endl;
+              << " C Type = " << DataTypeTraits<CDataType>::name << " : " << ave_time << " ms, "
+              << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;

     return ave_time;
 }
@@ -108,6 +110,7 @@ int run_gemm_example_with_layouts(int argc,
     ck_tile::index_t kbatch = arg_parser.get_int("split_k");
     int n_warmup = arg_parser.get_int("warmup");
     int n_repeat = arg_parser.get_int("repeat");
+    ck_tile::index_t init_method = arg_parser.get_int("init");

     stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
     stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
@@ -120,9 +123,26 @@ int run_gemm_example_with_layouts(int argc,
     ck_tile::HostTensor<CDataType> c_m_n_dev_result(
         ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));

-    // TODO: add different init types
-    ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
-    ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
+    if(init_method == 0)
+    {
+        ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
+        ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
+    }
+    else if(init_method == 1)
+    {
+        ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
+        ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
+    }
+    else if(init_method == 2)
+    {
+        ck_tile::FillConstant<ADataType>{static_cast<ADataType>(1)}(a_m_k);
+        ck_tile::FillConstant<BDataType>{static_cast<BDataType>(1)}(b_k_n);
+    }
+    else
+    {
+        a_m_k.SetZero();
+        b_k_n.SetZero();
+    }

     ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
     ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
@@ -133,8 +153,8 @@ int run_gemm_example_with_layouts(int argc,
     c_m_n_dev_buf.SetZero();
     c_m_n_dev_result.SetZero();

-    invoke_gemm<ADataType, BDataType, AccDataType, CDataType,
-                ALayout, BLayout, CLayout>(a_m_k_dev_buf,
+    invoke_gemm<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
+        a_m_k_dev_buf,
         b_k_n_dev_buf,
         c_m_n_dev_buf,
         M,
...
@@ -160,8 +180,8 @@ int run_gemm_example_with_layouts(int argc,
             a_m_k, b_k_n, c_m_n_host_ref);
         const float max_accumulated_value =
             *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
-        const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>
-            (K, kbatch, max_accumulated_value);
+        const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
+            K, kbatch, max_accumulated_value);
         pass = ck_tile::check_err(c_m_n_dev_result,
                                   c_m_n_host_ref,
                                   "Error: Incorrect results!",
@@ -171,7 +191,7 @@ int run_gemm_example_with_layouts(int argc,
         std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
                   << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
                   << std::endl;
-        std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
+        std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
     }
     else if(arg_parser.get_int("v") == 2)
     {
@@ -218,8 +238,8 @@ int run_gemm_example_with_layouts(int argc,
         c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
         const float max_accumulated_value =
             *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
-        const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>
-            (K, kbatch, max_accumulated_value);
+        const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
+            K, kbatch, max_accumulated_value);
         pass = ck_tile::check_err(c_m_n_dev_result,
                                   c_m_n_gpu_ref,
                                   "Error: Incorrect results!",
@@ -229,7 +249,7 @@ int run_gemm_example_with_layouts(int argc,
         std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
                   << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
                   << std::endl;
-        std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
+        std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
     }

     return pass;
...