Commit 37febb8d authored by Jakub Piasecki's avatar Jakub Piasecki
Browse files

pre-commited missing files

parent 15d96340
...@@ -34,8 +34,7 @@ __global__ void ...@@ -34,8 +34,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_gemm_xdl_splitk( kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count, const index_t group_count,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -226,7 +226,8 @@ bool profile_grouped_gemm_two_stage_impl(int do_verification, ...@@ -226,7 +226,8 @@ bool profile_grouped_gemm_two_stage_impl(int do_verification,
std::string gemm_name = gemm_ptr->GetTypeString(); std::string gemm_name = gemm_ptr->GetTypeString();
using DeviceOpSplitK = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitK<ALayout, using DeviceOpSplitK =
ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitK<ALayout,
BLayout, BLayout,
ck::Tuple<>, ck::Tuple<>,
CLayout, CLayout,
...@@ -258,8 +259,10 @@ bool profile_grouped_gemm_two_stage_impl(int do_verification, ...@@ -258,8 +259,10 @@ bool profile_grouped_gemm_two_stage_impl(int do_verification,
dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get()) dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
->SetKBatchSize(argument_ptr.get(), kbatch_curr); ->SetKBatchSize(argument_ptr.get(), kbatch_curr);
DeviceMem gemm_arg_dev_mem(dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())->GetDeviceKernelArgSize(argument_ptr.get())); DeviceMem gemm_arg_dev_mem(dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())->SetDeviceKernelArgs(argument_ptr.get(), gemm_arg_dev_mem.GetDeviceBuffer()); ->GetDeviceKernelArgSize(argument_ptr.get()));
dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
->SetDeviceKernelArgs(argument_ptr.get(), gemm_arg_dev_mem.GetDeviceBuffer());
if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment