"test/vscode:/vscode.git/clone" did not exist on "7dff24cae4b40d858439dbfd5b2d8be72b57ab70"
Unverified Commit 9488f1c9 authored by rtmadduri's avatar rtmadduri Committed by GitHub
Browse files

LWPCK-2429: Device grouped GEMM uses Async Memcpy (#1695)



* LWPCK-2429: Device grouped GEMM uses Async Memcpy
Resolving merge conflicts

* reverting changes to profile_grouped_gemm

* revert date change

---------
Co-authored-by: default avatarIllia Silin <98187287+illsilin@users.noreply.github.com>
parent 44828b7c
#pragma once #pragma once
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -603,11 +603,11 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout, ...@@ -603,11 +603,11 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
} }
hipGetErrorString( hipGetErrorString(
hipMemcpyWithStream(arg.p_workspace_, hipMemcpyAsync(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(), arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg), arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg),
hipMemcpyHostToDevice, hipMemcpyHostToDevice,
stream_config.stream_id_)); stream_config.stream_id_));
auto launch_kernel = [&](auto has_main_k_block_loop, auto launch_kernel = [&](auto has_main_k_block_loop,
auto has_double_tail_k_block_loop) { auto has_double_tail_k_block_loop) {
......
...@@ -761,11 +761,11 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ...@@ -761,11 +761,11 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
float time{0.f}; float time{0.f};
hip_check_error( hip_check_error(
hipMemcpyWithStream(dev_gemm_kargs, hipMemcpyAsync(dev_gemm_kargs,
arg.gemm_kernel_args_.data(), arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg), arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice, hipMemcpyHostToDevice,
stream_config.stream_id_)); stream_config.stream_id_));
auto preprocess = [&]() { auto preprocess = [&]() {
hip_check_error(hipMemsetAsync( hip_check_error(hipMemsetAsync(
......
...@@ -940,10 +940,10 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop ...@@ -940,10 +940,10 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
const void* p_host_kernel_args) const const void* p_host_kernel_args) const
{ {
arg.p_dev_gemm_args_ = p_dev_kernel_args; arg.p_dev_gemm_args_ = p_dev_kernel_args;
hip_check_error(hipMemcpy(p_dev_kernel_args, hip_check_error(hipMemcpyAsync(p_dev_kernel_args,
p_host_kernel_args, p_host_kernel_args,
GetDeviceKernelArgSize(&arg), GetDeviceKernelArgSize(&arg),
hipMemcpyHostToDevice)); hipMemcpyHostToDevice));
} }
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, virtual void SetDeviceKernelArgs(BaseArgument* p_arg,
......
...@@ -557,12 +557,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -557,12 +557,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
} }
} }
hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_, hipGetErrorString(
arg.gemm_desc_kernel_arg_.data(), hipMemcpyAsync(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.size() * arg.gemm_desc_kernel_arg_.data(),
sizeof(GemmBiasTransKernelArg), arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg),
hipMemcpyHostToDevice, hipMemcpyHostToDevice,
stream_config.stream_id_)); stream_config.stream_id_));
float ave_time = 0; float ave_time = 0;
......
...@@ -421,11 +421,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -421,11 +421,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
} }
hip_check_error( hip_check_error(
hipMemcpyWithStream(arg.p_workspace_, hipMemcpyAsync(arg.p_workspace_,
arg.gemm_kernel_args_.data(), arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg), arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice, hipMemcpyHostToDevice,
stream_config.stream_id_)); stream_config.stream_id_));
float ave_time = 0; float ave_time = 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment