Commit 536c5458 authored by ThomasNing
Browse files

fix with better naming convention

parent 04006d5f
...@@ -68,8 +68,8 @@ float gemm_calc(gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -68,8 +68,8 @@ float gemm_calc(gemm_basic_args& args, const ck_tile::stream_config& s)
// =============================================== // ===============================================
using Shape = ck_tile::TileGemmShapeNewGemm< using Shape =
ck_tile::sequence<M_Tile, N_Tile, K_Tile>, ck_tile::TileGemmShapeNewGemm<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>, ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>; ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<Shape>; using TilePartitioner = ck_tile::GemmTilePartitioner<Shape>;
...@@ -83,9 +83,9 @@ float gemm_calc(gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -83,9 +83,9 @@ float gemm_calc(gemm_basic_args& args, const ck_tile::stream_config& s)
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, Layouts>; using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, Layouts>;
auto kargs = Kernel::MakeKargs(args.p_x, auto kargs = Kernel::MakeKargs(args.p_a,
args.p_y, args.p_b,
args.p_z, args.p_c,
args.batch_size, args.batch_size,
args.epsilon, args.epsilon,
args.M, args.M,
...@@ -105,9 +105,9 @@ float gemm_calc(gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -105,9 +105,9 @@ float gemm_calc(gemm_basic_args& args, const ck_tile::stream_config& s)
} }
template <typename DataType, typename Layouts> template <typename DataType, typename Layouts>
float OperatorExecution(ck_tile::DeviceMem& x_buf, float OperatorExecution(ck_tile::DeviceMem& a_buf,
ck_tile::DeviceMem& y_buf, ck_tile::DeviceMem& b_buf,
ck_tile::DeviceMem& z_buf, ck_tile::DeviceMem& c_buf,
const ck_tile::ArgParser& arg_parser) const ck_tile::ArgParser& arg_parser)
{ {
...@@ -131,9 +131,9 @@ float OperatorExecution(ck_tile::DeviceMem& x_buf, ...@@ -131,9 +131,9 @@ float OperatorExecution(ck_tile::DeviceMem& x_buf,
ck_tile::index_t stride_c = arg_parser.get_int("stride_c"); ck_tile::index_t stride_c = arg_parser.get_int("stride_c");
gemm_basic_args args; gemm_basic_args args;
args.p_x = x_buf.GetDeviceBuffer(); args.p_a = a_buf.GetDeviceBuffer();
args.p_y = y_buf.GetDeviceBuffer(); args.p_b = b_buf.GetDeviceBuffer();
args.p_z = z_buf.GetDeviceBuffer(); args.p_c = c_buf.GetDeviceBuffer();
args.epsilon = epsilon; args.epsilon = epsilon;
args.batch_size = batch_size; args.batch_size = batch_size;
args.M = M; args.M = M;
...@@ -222,37 +222,38 @@ int main(int argc, char* argv[]) ...@@ -222,37 +222,38 @@ int main(int argc, char* argv[])
ck_tile::index_t N = arg_parser.get_int("n"); ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k"); ck_tile::index_t K = arg_parser.get_int("k");
// The Matrix Multiplication goes with Matrix A (M, K), Matrix B (N, K) = Matrix C (M, N).
constexpr ck_tile::MatrixALayout matrix_a_layout = ck_tile::MatrixALayout::MK; constexpr ck_tile::MatrixALayout matrix_a_layout = ck_tile::MatrixALayout::MK;
constexpr ck_tile::MatrixBLayout matrix_b_layout = ck_tile::MatrixBLayout::NK; constexpr ck_tile::MatrixBLayout matrix_b_layout = ck_tile::MatrixBLayout::NK;
constexpr ck_tile::MatrixCLayout matrix_c_layout = ck_tile::MatrixCLayout::MN; constexpr ck_tile::MatrixCLayout matrix_c_layout = ck_tile::MatrixCLayout::MN;
using Layouts = LayoutConfig<matrix_a_layout, matrix_b_layout, matrix_c_layout>; using Layouts = LayoutConfig<matrix_a_layout, matrix_b_layout, matrix_c_layout>;
// host verify // host verify
std::vector<int> x_dimensions = (matrix_a_layout == ck_tile::MatrixALayout::MK) std::vector<int> a_dimensions = (matrix_a_layout == ck_tile::MatrixALayout::MK)
? std::vector<int>{M, K} ? std::vector<int>{M, K}
: std::vector<int>{K, M}; : std::vector<int>{K, M};
std::vector<int> y_dimensions = (matrix_b_layout == ck_tile::MatrixBLayout::NK) std::vector<int> b_dimensions = (matrix_b_layout == ck_tile::MatrixBLayout::NK)
? std::vector<int>{N, K} ? std::vector<int>{N, K}
: std::vector<int>{K, N}; : std::vector<int>{K, N};
std::vector<int> z_dimensions = (matrix_c_layout == ck_tile::MatrixCLayout::MN) std::vector<int> c_dimensions = (matrix_c_layout == ck_tile::MatrixCLayout::MN)
? std::vector<int>{M, N} ? std::vector<int>{M, N}
: std::vector<int>{N, M}; : std::vector<int>{N, M};
ck_tile::HostTensor<XDataType> x_host(x_dimensions); ck_tile::HostTensor<XDataType> a_host(a_dimensions);
ck_tile::HostTensor<YDataType> y_host(y_dimensions); ck_tile::HostTensor<YDataType> b_host(b_dimensions);
ck_tile::HostTensor<ODataType> z_host_ref(z_dimensions); ck_tile::HostTensor<ODataType> c_host_ref(c_dimensions);
ck_tile::HostTensor<ODataType> z_host_dev(z_dimensions); ck_tile::HostTensor<ODataType> c_host_dev(c_dimensions);
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host); ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(a_host);
ck_tile::FillUniformDistribution<YDataType>{-5.f, 5.f}(y_host); ck_tile::FillUniformDistribution<YDataType>{-5.f, 5.f}(b_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem z_buf(z_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data()); a_buf.ToDevice(a_host.data());
y_buf.ToDevice(y_host.data()); b_buf.ToDevice(b_host.data());
if(grouped_enable || following_op_descrp != "no") if(grouped_enable || following_op_descrp != "no")
{ {
...@@ -260,7 +261,7 @@ int main(int argc, char* argv[]) ...@@ -260,7 +261,7 @@ int main(int argc, char* argv[])
return -1; return -1;
} }
OperatorExecution<ck_tile::half_t, Layouts>(x_buf, y_buf, z_buf, arg_parser); OperatorExecution<ck_tile::half_t, Layouts>(a_buf, b_buf, c_buf, arg_parser);
bool pass = true; bool pass = true;
...@@ -268,11 +269,11 @@ int main(int argc, char* argv[]) ...@@ -268,11 +269,11 @@ int main(int argc, char* argv[])
{ {
// ToDo: Will Add the Element Op (bias) verification in the future. // ToDo: Will Add the Element Op (bias) verification in the future.
ck_tile::reference_gemm<XDataType, YDataType, AccDataType, ODataType>( ck_tile::reference_gemm<XDataType, YDataType, AccDataType, ODataType>(
x_host, y_host, z_host_ref, matrix_a_layout); a_host, b_host, c_host_ref, matrix_a_layout);
z_buf.FromDevice(z_host_dev.data()); c_buf.FromDevice(c_host_dev.data());
pass = ck_tile::check_err(z_host_dev, z_host_ref); pass = ck_tile::check_err(c_host_dev, c_host_ref);
std::cout << "The veification result is:" << (pass ? "correct" : "fail") << std::flush; std::cout << "The veification result is:" << (pass ? "correct" : "fail") << std::flush;
} }
......
...@@ -62,9 +62,9 @@ using ODataType = Types::ODataType; ...@@ -62,9 +62,9 @@ using ODataType = Types::ODataType;
struct gemm_basic_args struct gemm_basic_args
{ {
const void* p_x; const void* p_a;
const void* p_y; const void* p_b;
void* p_z; void* p_c;
float epsilon; float epsilon;
ck_tile::index_t batch_size; ck_tile::index_t batch_size;
ck_tile::index_t M; ck_tile::index_t M;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment