Unverified Commit 9488f1c9 authored by rtmadduri's avatar rtmadduri Committed by GitHub
Browse files

LWPCK-2429: Device grouped GEMM uses Async Memcpy (#1695)



* LWPCK-2429: Device grouped GEMM uses Async Memcpy
Resolving merge conflicts

* reverting changes to profile_grouped_gemm

* revert date change

---------
Co-authored-by: default avatarIllia Silin <98187287+illsilin@users.noreply.github.com>
parent 44828b7c
#pragma once #pragma once
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -603,7 +603,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout, ...@@ -603,7 +603,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
} }
hipGetErrorString( hipGetErrorString(
hipMemcpyWithStream(arg.p_workspace_, hipMemcpyAsync(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(), arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg), arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg),
hipMemcpyHostToDevice, hipMemcpyHostToDevice,
......
...@@ -761,7 +761,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ...@@ -761,7 +761,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
float time{0.f}; float time{0.f};
hip_check_error( hip_check_error(
hipMemcpyWithStream(dev_gemm_kargs, hipMemcpyAsync(dev_gemm_kargs,
arg.gemm_kernel_args_.data(), arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg), arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice, hipMemcpyHostToDevice,
......
...@@ -940,7 +940,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop ...@@ -940,7 +940,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
const void* p_host_kernel_args) const const void* p_host_kernel_args) const
{ {
arg.p_dev_gemm_args_ = p_dev_kernel_args; arg.p_dev_gemm_args_ = p_dev_kernel_args;
hip_check_error(hipMemcpy(p_dev_kernel_args, hip_check_error(hipMemcpyAsync(p_dev_kernel_args,
p_host_kernel_args, p_host_kernel_args,
GetDeviceKernelArgSize(&arg), GetDeviceKernelArgSize(&arg),
hipMemcpyHostToDevice)); hipMemcpyHostToDevice));
......
...@@ -557,10 +557,10 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -557,10 +557,10 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
} }
} }
hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_, hipGetErrorString(
hipMemcpyAsync(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(), arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg),
sizeof(GemmBiasTransKernelArg),
hipMemcpyHostToDevice, hipMemcpyHostToDevice,
stream_config.stream_id_)); stream_config.stream_id_));
......
...@@ -421,7 +421,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -421,7 +421,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
} }
hip_check_error( hip_check_error(
hipMemcpyWithStream(arg.p_workspace_, hipMemcpyAsync(arg.p_workspace_,
arg.gemm_kernel_args_.data(), arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg), arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice, hipMemcpyHostToDevice,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment