"vscode:/vscode.git/clone" did not exist on "1f5f17c5b4824c877ca61ed7955757c2c204c6e7"
Commit 37febb8d authored by Jakub Piasecki's avatar Jakub Piasecki
Browse files

pre-commited missing files

parent 15d96340
...@@ -34,12 +34,11 @@ __global__ void ...@@ -34,12 +34,11 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_gemm_xdl_splitk( kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, const index_t group_count,
const index_t group_count, const AElementwiseOperation a_element_op,
const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op,
const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op)
const CElementwiseOperation c_element_op)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__)) defined(__gfx94__))
...@@ -206,7 +205,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -206,7 +205,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
static constexpr index_t B2E_M01 = 8; static constexpr index_t B2E_M01 = 8;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>; using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
using KernelArgument = typename GridwiseGemm::Argument; using KernelArgument = typename GridwiseGemm::Argument;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
struct GemmTransKernelArg struct GemmTransKernelArg
{ {
KernelArgument karg_; KernelArgument karg_;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -34,18 +34,18 @@ template <typename ADataType, ...@@ -34,18 +34,18 @@ template <typename ADataType,
typename BLayout, typename BLayout,
typename CLayout> typename CLayout>
bool profile_grouped_gemm_two_stage_impl(int do_verification, bool profile_grouped_gemm_two_stage_impl(int do_verification,
int init_method, int init_method,
bool do_log, bool do_log,
bool time_kernel, bool time_kernel,
const std::vector<int>& Ms, const std::vector<int>& Ms,
const std::vector<int>& Ns, const std::vector<int>& Ns,
const std::vector<int>& Ks, const std::vector<int>& Ks,
const std::vector<int>& StrideAs, const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs, const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs, const std::vector<int>& StrideCs,
int kbatch = 1, int kbatch = 1,
int n_warmup = 1, int n_warmup = 1,
int n_iter = 10) int n_iter = 10)
{ {
bool pass = true; bool pass = true;
...@@ -226,17 +226,18 @@ bool profile_grouped_gemm_two_stage_impl(int do_verification, ...@@ -226,17 +226,18 @@ bool profile_grouped_gemm_two_stage_impl(int do_verification,
std::string gemm_name = gemm_ptr->GetTypeString(); std::string gemm_name = gemm_ptr->GetTypeString();
using DeviceOpSplitK = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitK<ALayout, using DeviceOpSplitK =
BLayout, ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitK<ALayout,
ck::Tuple<>, BLayout,
CLayout, ck::Tuple<>,
ADataType, CLayout,
BDataType, ADataType,
ck::Tuple<>, BDataType,
CDataType, ck::Tuple<>,
AElementOp, CDataType,
BElementOp, AElementOp,
CElementOp>; BElementOp,
CElementOp>;
// skip non-splitk grouped_gemm // skip non-splitk grouped_gemm
if(dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get()) == nullptr) if(dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get()) == nullptr)
...@@ -258,8 +259,10 @@ bool profile_grouped_gemm_two_stage_impl(int do_verification, ...@@ -258,8 +259,10 @@ bool profile_grouped_gemm_two_stage_impl(int do_verification,
dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get()) dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
->SetKBatchSize(argument_ptr.get(), kbatch_curr); ->SetKBatchSize(argument_ptr.get(), kbatch_curr);
DeviceMem gemm_arg_dev_mem(dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())->GetDeviceKernelArgSize(argument_ptr.get())); DeviceMem gemm_arg_dev_mem(dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())->SetDeviceKernelArgs(argument_ptr.get(), gemm_arg_dev_mem.GetDeviceBuffer()); ->GetDeviceKernelArgSize(argument_ptr.get()));
dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
->SetDeviceKernelArgs(argument_ptr.get(), gemm_arg_dev_mem.GetDeviceBuffer());
if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment