Commit c2945b96 authored by Jakub Piasecki's avatar Jakub Piasecki
Browse files

add reviewers comments

parent aa30ef56
......@@ -6,22 +6,6 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// NOTE(review): the two instantiations below are the before/after variants
// from this diff; both name the same kernel configuration (the first spells
// it through the `trait_` alias, the second through `gemm_traits_` directly).
// Configuration: fp16 A/B/C with fp32 accumulation, layouts Row/Col/Row,
// 128x32x64 block tile, 4x1x1 warp layout, 32x32x8 warp tile, no M/N/K padding.
template float gemm_<trait_<ck_tile::half_t,
ck_tile::half_t,
float,
ck_tile::half_t,
Row,
Col,
Row,
128,
32,
64,
4,
1,
1,
32,
32,
8,
false,
false,
false>>(const A&, const S&);
// clang-format off
template float gemm_<gemm_traits_<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t, Row, Col, Row, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, false, false>>(const A&, const S&);
// clang-format on
......@@ -2,50 +2,11 @@
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include <iostream>
#include "gemm_basic.hpp"
#include "gemm.hpp"
using A = ck_tile::GemmHostArgs;
using S = ck_tile::stream_config;
// Convenience alias: forwards every compile-time GEMM configuration
// parameter straight through to gemm_traits_, unchanged and in the same
// order, so instantiation sites can use the shorter `trait_` spelling.
template <typename ADataType_, typename BDataType_, typename AccDataType_,
          typename CDataType_, typename ALayout_, typename BLayout_,
          typename CLayout_, ck_tile::index_t M_Tile_, ck_tile::index_t N_Tile_,
          ck_tile::index_t K_Tile_, ck_tile::index_t M_Warp_,
          ck_tile::index_t N_Warp_, ck_tile::index_t K_Warp_,
          ck_tile::index_t M_Warp_Tile_, ck_tile::index_t N_Warp_Tile_,
          ck_tile::index_t K_Warp_Tile_, bool kPadM_, bool kPadN_, bool kPadK_>
using trait_ = gemm_traits_<ADataType_, BDataType_, AccDataType_, CDataType_,
                            ALayout_, BLayout_, CLayout_, M_Tile_, N_Tile_,
                            K_Tile_, M_Warp_, N_Warp_, K_Warp_, M_Warp_Tile_,
                            N_Warp_Tile_, K_Warp_Tile_, kPadM_, kPadN_, kPadK_>;
template <typename Traits_>
float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
......
......@@ -2,6 +2,29 @@
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3840", "m dimension")
.insert("n", "4096", "n dimension")
.insert("k", "2048", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "R", "B tensor data layout - Row by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename ALayout, typename BLayout, typename CLayout>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
......@@ -28,6 +51,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args.stride_B = stride_B;
args.stride_C = stride_C;
// TODO: Change datatypes in future to allow mixed precision gemms!
gemm_traits traits{DataTypeTraits<ADataType>{}.name,
std::is_same_v<ALayout, Row>,
std::is_same_v<BLayout, Row>,
......
......@@ -10,30 +10,7 @@
#include <tuple>
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"
// Register the command-line options for the GEMM example and parse argv.
// Returns a tuple of (parse succeeded, populated parser); defaults describe
// a 3840x4096x2048 fp16 row/row/row run with GPU validation.
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
// Problem shape, layouts/strides (stride 0 = derived), then validation,
// precision and benchmarking controls — all registered with defaults.
arg_parser.insert("m", "3840", "m dimension")
.insert("n", "4096", "n dimension")
.insert("k", "2048", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "R", "B tensor data layout - Row by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
#include "gemm.hpp"
#include "run_gemm_example.inc"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment