Commit 536c5458 authored by ThomasNing

fix with better naming convention

parent 04006d5f
@@ -68,8 +68,8 @@ float gemm_calc(gemm_basic_args& args, const ck_tile::stream_config& s)
     // ===============================================
-    using Shape = ck_tile::TileGemmShapeNewGemm<
-        ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
+    using Shape =
+        ck_tile::TileGemmShapeNewGemm<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
                                       ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                                       ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
     using TilePartitioner = ck_tile::GemmTilePartitioner<Shape>;
@@ -83,9 +83,9 @@ float gemm_calc(gemm_basic_args& args, const ck_tile::stream_config& s)
     // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
     using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, Layouts>;
-    auto kargs = Kernel::MakeKargs(args.p_x,
-                                   args.p_y,
-                                   args.p_z,
+    auto kargs = Kernel::MakeKargs(args.p_a,
+                                   args.p_b,
+                                   args.p_c,
                                    args.batch_size,
                                    args.epsilon,
                                    args.M,
@@ -105,9 +105,9 @@ float gemm_calc(gemm_basic_args& args, const ck_tile::stream_config& s)
 }
 template <typename DataType, typename Layouts>
-float OperatorExecution(ck_tile::DeviceMem& x_buf,
-                        ck_tile::DeviceMem& y_buf,
-                        ck_tile::DeviceMem& z_buf,
+float OperatorExecution(ck_tile::DeviceMem& a_buf,
+                        ck_tile::DeviceMem& b_buf,
+                        ck_tile::DeviceMem& c_buf,
                         const ck_tile::ArgParser& arg_parser)
 {
@@ -131,9 +131,9 @@ float OperatorExecution(ck_tile::DeviceMem& x_buf,
     ck_tile::index_t stride_c = arg_parser.get_int("stride_c");
     gemm_basic_args args;
-    args.p_x = x_buf.GetDeviceBuffer();
-    args.p_y = y_buf.GetDeviceBuffer();
-    args.p_z = z_buf.GetDeviceBuffer();
+    args.p_a = a_buf.GetDeviceBuffer();
+    args.p_b = b_buf.GetDeviceBuffer();
+    args.p_c = c_buf.GetDeviceBuffer();
     args.epsilon = epsilon;
     args.batch_size = batch_size;
     args.M = M;
@@ -222,37 +222,38 @@ int main(int argc, char* argv[])
     ck_tile::index_t N = arg_parser.get_int("n");
     ck_tile::index_t K = arg_parser.get_int("k");
     // The Matrix Multiplication goes with Matrix A (M, K), Matrix B (N, K) = Matrix C (M, N).
     constexpr ck_tile::MatrixALayout matrix_a_layout = ck_tile::MatrixALayout::MK;
     constexpr ck_tile::MatrixBLayout matrix_b_layout = ck_tile::MatrixBLayout::NK;
     constexpr ck_tile::MatrixCLayout matrix_c_layout = ck_tile::MatrixCLayout::MN;
     using Layouts = LayoutConfig<matrix_a_layout, matrix_b_layout, matrix_c_layout>;
     // host verify
-    std::vector<int> x_dimensions = (matrix_a_layout == ck_tile::MatrixALayout::MK)
+    std::vector<int> a_dimensions = (matrix_a_layout == ck_tile::MatrixALayout::MK)
                                         ? std::vector<int>{M, K}
                                         : std::vector<int>{K, M};
-    std::vector<int> y_dimensions = (matrix_b_layout == ck_tile::MatrixBLayout::NK)
+    std::vector<int> b_dimensions = (matrix_b_layout == ck_tile::MatrixBLayout::NK)
                                         ? std::vector<int>{N, K}
                                         : std::vector<int>{K, N};
-    std::vector<int> z_dimensions = (matrix_c_layout == ck_tile::MatrixCLayout::MN)
+    std::vector<int> c_dimensions = (matrix_c_layout == ck_tile::MatrixCLayout::MN)
                                         ? std::vector<int>{M, N}
                                         : std::vector<int>{N, M};
-    ck_tile::HostTensor<XDataType> x_host(x_dimensions);
-    ck_tile::HostTensor<YDataType> y_host(y_dimensions);
+    ck_tile::HostTensor<XDataType> a_host(a_dimensions);
+    ck_tile::HostTensor<YDataType> b_host(b_dimensions);
-    ck_tile::HostTensor<ODataType> z_host_ref(z_dimensions);
-    ck_tile::HostTensor<ODataType> z_host_dev(z_dimensions);
+    ck_tile::HostTensor<ODataType> c_host_ref(c_dimensions);
+    ck_tile::HostTensor<ODataType> c_host_dev(c_dimensions);
-    ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);
-    ck_tile::FillUniformDistribution<YDataType>{-5.f, 5.f}(y_host);
+    ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(a_host);
+    ck_tile::FillUniformDistribution<YDataType>{-5.f, 5.f}(b_host);
-    ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
-    ck_tile::DeviceMem y_buf(y_host.get_element_space_size_in_bytes());
-    ck_tile::DeviceMem z_buf(z_host_dev.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes());
-    x_buf.ToDevice(x_host.data());
-    y_buf.ToDevice(y_host.data());
+    a_buf.ToDevice(a_host.data());
+    b_buf.ToDevice(b_host.data());
     if(grouped_enable || following_op_descrp != "no")
     {
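The layout comment in the hunk above means that, with A stored as (M, K), B as (N, K), and C as (M, N), all row-major, the kernel effectively computes C = A * B^T, accumulating A(m, k) * B(n, k) over k. Below is a minimal plain-C++ sketch of only that indexing convention; it is not code from this commit, and the name naive_gemm_mk_nk_mn is made up for illustration.

#include <cstddef>
#include <vector>

// Plain-C++ sketch of the MK x NK -> MN convention: A is row-major (M, K),
// B is row-major (N, K), C is row-major (M, N), so each C(m, n) accumulates
// A(m, k) * B(n, k) over k.
template <typename T, typename Acc = float>
void naive_gemm_mk_nk_mn(const std::vector<T>& a, // size M * K
                         const std::vector<T>& b, // size N * K
                         std::vector<Acc>& c,     // size M * N
                         std::size_t M,
                         std::size_t N,
                         std::size_t K)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            Acc acc = 0;
            for(std::size_t k = 0; k < K; ++k)
                acc += static_cast<Acc>(a[m * K + k]) * static_cast<Acc>(b[n * K + k]);
            c[m * N + n] = acc;
        }
}

ck_tile::reference_gemm, called later in this diff, plays the same role for the host-side verification, dispatching on matrix_a_layout.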
@@ -260,7 +261,7 @@ int main(int argc, char* argv[])
         return -1;
     }
-    OperatorExecution<ck_tile::half_t, Layouts>(x_buf, y_buf, z_buf, arg_parser);
+    OperatorExecution<ck_tile::half_t, Layouts>(a_buf, b_buf, c_buf, arg_parser);
     bool pass = true;
@@ -268,11 +269,11 @@ int main(int argc, char* argv[])
     {
         // ToDo: Will Add the Element Op (bias) verification in the future.
         ck_tile::reference_gemm<XDataType, YDataType, AccDataType, ODataType>(
-            x_host, y_host, z_host_ref, matrix_a_layout);
+            a_host, b_host, c_host_ref, matrix_a_layout);
-        z_buf.FromDevice(z_host_dev.data());
+        c_buf.FromDevice(c_host_dev.data());
-        pass = ck_tile::check_err(z_host_dev, z_host_ref);
+        pass = ck_tile::check_err(c_host_dev, c_host_ref);
         std::cout << "The veification result is:" << (pass ? "correct" : "fail") << std::flush;
     }
@@ -62,9 +62,9 @@ using ODataType = Types::ODataType;
 struct gemm_basic_args
 {
-    const void* p_x;
-    const void* p_y;
-    void* p_z;
+    const void* p_a;
+    const void* p_b;
+    void* p_c;
     float epsilon;
     ck_tile::index_t batch_size;
    ck_tile::index_t M;
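For readers following the rename across both files, the sketch below shows how the new p_a, p_b, and p_c fields of gemm_basic_args are filled from the device buffers before the kernel launch. Only identifiers that appear in this diff are taken from the commit; the helper name make_renamed_args and its exact parameter list are assumptions for illustration.

// Assumed context: the example's header that declares gemm_basic_args and the
// ck_tile types used here. Illustrative helper, not part of the commit.
gemm_basic_args make_renamed_args(ck_tile::DeviceMem& a_buf,
                                  ck_tile::DeviceMem& b_buf,
                                  ck_tile::DeviceMem& c_buf,
                                  ck_tile::index_t M)
{
    gemm_basic_args args;
    args.p_a = a_buf.GetDeviceBuffer(); // const void*: input matrix A (read-only)
    args.p_b = b_buf.GetDeviceBuffer(); // const void*: input matrix B (read-only)
    args.p_c = c_buf.GetDeviceBuffer(); // void*: output matrix C, written by the kernel
    args.M   = M;
    // epsilon, batch_size, and the remaining size/stride fields are filled the
    // same way OperatorExecution does above before calling gemm_calc(args, s).
    return args;
}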