Unverified Commit 77fa9fda authored by arai713, committed by GitHub

Merge branch 'develop' into codegen_hiprtc

parents 760ea189 e7b62864
@@ -121,7 +121,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
     constexpr ck::index_t NumDTensor = 2;
     using GroupedGemmKernelArgument =
-        ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<NumDTensor>;
+        ck::tensor_operation::device::GroupedGemmKernelArgument<NumDTensor>;
     std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
     grouped_gemm_kernel_args_.reserve(group_count);

@@ -120,7 +120,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
     constexpr ck::index_t NumDTensor = 1;
     using GroupedGemmKernelArgument =
-        ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<NumDTensor>;
+        ck::tensor_operation::device::GroupedGemmKernelArgument<NumDTensor>;
     std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
     grouped_gemm_kernel_args_.reserve(group_count);

@@ -7,6 +7,7 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
 set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
 set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
 set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..)
+configure_file(${CK_ROOT}/include/ck/config.h.in ${CK_ROOT}/include/ck/config.h)
 find_package(ROCM)
 include(ROCMInstallTargets)

@@ -246,7 +246,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
     // do GEMM
     auto argument = gemm.MakeArgument(
         p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
-    gemm.SetKBatchSize(argument, config.k_batch);
+    gemm.SetKBatchSize(&argument, config.k_batch);
     if(!gemm.IsSupportedArgument(argument))
     {
         throw std::runtime_error(
@@ -257,7 +257,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
     gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());
     DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
-    gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
+    gemm.SetDeviceKernelArgs(&argument, gemm_arg_dev_mem.GetDeviceBuffer());
     invoker.Run(argument, StreamConfig{nullptr, false, 1});

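The two hunks above switch the grouped GEMM split-K example to setter APIs that take the argument object by pointer. As a condensed sketch of the resulting setup-and-run sequence, reusing the example's own gemm, invoker, config, and DeviceMem objects (the error message here is paraphrased, not the exact string from the example):

// Build the argument, then attach split-K factor and device-side kernel
// argument buffer through the pointer-taking setters before launching.
auto argument = gemm.MakeArgument(
    p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
gemm.SetKBatchSize(&argument, config.k_batch);

if(!gemm.IsSupportedArgument(argument))
{
    throw std::runtime_error("this device op instance does not support the GEMM problem");
}

gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());
DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
gemm.SetDeviceKernelArgs(&argument, gemm_arg_dev_mem.GetDeviceBuffer());

invoker.Run(argument, StreamConfig{nullptr, false, 1});
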
@@ -91,7 +91,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
 {
     auto group_count = problem_size.group_count;
-    using KernelArguments = ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<NumDs>;
+    using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument<NumDs>;
     using GemmDesc = ck::tensor_operation::device::GemmDesc;
     // GEMM shape

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #include <iostream>
 #include <numeric>
@@ -254,7 +254,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
                            gemm.GetDeviceKernelArgSize(&argument),
                            hipMemcpyHostToDevice));
-    gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer());
+    gemm.SetDeviceKernelArgs(&argument, gemm_kernel_args_dev.GetDeviceBuffer());
     gemm.SetKBatch(argument, config.k_batch);
     invoker.Run(argument, StreamConfig{nullptr, false});

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #include <iostream>
 #include <numeric>
@@ -239,7 +239,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
                                 "not support this GEMM problem");
     }
-    gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
+    gemm.SetDeviceKernelArgs(&argument, gemm_arg_dev_mem.GetDeviceBuffer());
     gemm.SetKBatch(argument, config.k_batch);
     invoker.Run(argument, StreamConfig{nullptr, false});

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #include <iostream>
 #include <numeric>
@@ -240,7 +240,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
                                 "not support this GEMM problem");
     }
-    gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
+    gemm.SetDeviceKernelArgs(&argument, gemm_arg_dev_mem.GetDeviceBuffer());
     gemm.SetKBatch(argument, config.k_batch);
     invoker.Run(argument, StreamConfig{nullptr, false});

@@ -168,9 +168,23 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
     auto argument = gemm.MakeArgument(
         p_a, p_b, p_Ds, p_c, gemm_descs, a_element_op, b_element_op, c_element_op);
-    DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));
-    gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
+    std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument);
+    std::size_t kargs_size = gemm.GetDeviceKernelArgSize(&argument);
+    DeviceMem gemm_workspace, gemm_kargs;
+    // The following is necessary since TwoStage kernel is using additional memory both
+    // for Workspace and kernel arguments.
+    if(kargs_size > 0)
+    {
+        gemm_kargs.Realloc(kargs_size);
+        gemm.SetDeviceKernelArgs(&argument, gemm_kargs.GetDeviceBuffer());
+    }
+    if(workspace_size > 0 && workspace_size != kargs_size)
+    {
+        gemm_workspace.Realloc(workspace_size);
+        gemm.SetWorkSpacePointer(&argument, gemm_workspace.GetDeviceBuffer());
+    }
     if(!gemm.IsSupportedArgument(argument))
     {

@@ -30,7 +30,6 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
     constexpr ck_tile::index_t M_Warp_Tile = 32;
     constexpr ck_tile::index_t N_Warp_Tile = 32;
-    constexpr ck_tile::index_t K_Warp_Tile = 8;
 #else
     // Compute friendly for Intrawave scheduler
     constexpr ck_tile::index_t M_Tile = 256;
@@ -84,7 +83,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
         AccDataType,
         GemmShape,
         Traits,
-        ck_tile::GemmPipelineScheduler::Intrawave,
+        ck_tile::GemmPipelineScheduler::Interwave,
         has_hot_loop_v,
         tail_number_v>>;
     using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;

@@ -200,7 +200,8 @@ int run_gemm_example(int argc, char* argv[])
         return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
     }
     // TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not
-    // work. else if(a_layout == "C" && b_layout == "C")
+    // work.
+    // else if(a_layout == "C" && b_layout == "C")
     // {
     //     return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
     // }
