Commit 2db781e9 authored by Ville Pietilä's avatar Ville Pietilä
Browse files

Merge remote-tracking branch 'origin/rimadduri/grouped_gemm_async_memcpy' into...

Merge remote-tracking branch 'origin/rimadduri/grouped_gemm_async_memcpy' into vpietila/ggemm-profiling
parents 578f1d27 f9466a75
rocm-docs-core==1.9.1
rocm-docs-core==1.9.2
sphinxcontrib-bibtex==2.6.3
......@@ -103,7 +103,7 @@ requests==2.32.3
# via
# pygithub
# sphinx
rocm-docs-core==1.9.1
rocm-docs-core==1.9.2
# via -r requirements.in
six==1.16.0
# via pybtex
......
#pragma once
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -603,11 +603,11 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
}
hipGetErrorString(
hipMemcpyWithStream(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
hipMemcpyAsync(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
auto launch_kernel = [&](auto has_main_k_block_loop,
auto has_double_tail_k_block_loop) {
......
......@@ -761,11 +761,11 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
float time{0.f};
hip_check_error(
hipMemcpyWithStream(dev_gemm_kargs,
arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
hipMemcpyAsync(dev_gemm_kargs,
arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
auto preprocess = [&]() {
hip_check_error(hipMemsetAsync(
......
......@@ -940,10 +940,10 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
const void* p_host_kernel_args) const
{
arg.p_dev_gemm_args_ = p_dev_kernel_args;
hip_check_error(hipMemcpy(p_dev_kernel_args,
p_host_kernel_args,
GetDeviceKernelArgSize(&arg),
hipMemcpyHostToDevice));
hip_check_error(hipMemcpyAsync(p_dev_kernel_args,
p_host_kernel_args,
GetDeviceKernelArgSize(&arg),
hipMemcpyHostToDevice));
}
virtual void SetDeviceKernelArgs(BaseArgument* p_arg,
......
......@@ -557,12 +557,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
}
}
hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() *
sizeof(GemmBiasTransKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
hipGetErrorString(
hipMemcpyAsync(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0;
......
......@@ -421,11 +421,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
}
hip_check_error(
hipMemcpyWithStream(arg.p_workspace_,
arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
hipMemcpyAsync(arg.p_workspace_,
arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0;
......
......@@ -623,7 +623,7 @@ struct BlockUniversalGemmAsBsCr
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
block_gemm_impl_.template LocalPrefetch(a_block_window, b_block_window);
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window);
}
// C += A * B
......@@ -632,7 +632,7 @@ struct BlockUniversalGemmAsBsCr
const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
block_gemm_impl_.template operator()(c_block_tensor, a_block_window, b_block_window);
block_gemm_impl_(c_block_tensor, a_block_window, b_block_window);
}
// C = A * B
......@@ -641,7 +641,7 @@ struct BlockUniversalGemmAsBsCr
const BSmemBlockWindow& b_block_window)
{
auto c_block_tensor = MakeCBlockTile();
block_gemm_impl_.template operator()(c_block_tensor, a_block_window, b_block_window);
block_gemm_impl_(c_block_tensor, a_block_window, b_block_window);
return c_block_tensor;
}
......
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
import logging
import os
import subprocess
from dataclasses import replace
from functools import lru_cache
from typing import List
from ..util import library_path
from .op import CKBatchedGemmOperation
log = logging.getLogger(__name__)
def _ck_library_dir():
gemm_instances_path = os.path.join(
library_path(),
"src",
"tensor_operation_instance",
"gpu",
"gemm_universal_batched",
)
if not os.path.exists(gemm_instances_path):
log.error("CK library path %s does not exist", gemm_instances_path)
return None
return gemm_instances_path
def parse_instances(str_instances: List[str]) -> List[CKBatchedGemmOperation]:
"""
Parse the lines containing Universal Gemm template instances into `CKBatchedGemmOperation` instances
"""
def maybe_int(s):
try:
return int(s)
except ValueError:
return s
op_instances = []
for line in str_instances:
s_template_args = line.split("DeviceBatchedGemmMultiD_Xdl_CShuffle_V3")[
-1
].strip("<>, ")
template_args = []
i_current = 0
while i_current < len(s_template_args):
if s_template_args[i_current] == " ":
# skip whitespace
i_current += 1
continue
elif s_template_args[i_current : i_current + 2] == "S<":
# parse template S<Index...>
i_next = s_template_args.find(">", i_current)
template_args.append(
tuple(map(int, s_template_args[i_current + 2 : i_next].split(",")))
)
i_current = i_next + 2
else:
# all string attributes must be either type aliases or global constants in C++
i_next = s_template_args.find(",", i_current)
template_args.append(
maybe_int(
s_template_args[i_current : i_next if i_next != -1 else None]
)
)
if i_next != -1:
i_current = i_next + 1
if i_next == -1:
break
# ds layout and dtype are parsed as placeholder; reset value
template_args[2] = tuple() # ds layout
template_args[6] = tuple() # ds dtype
new_instance = CKBatchedGemmOperation(
*template_args, # type: ignore[arg-type]
)
op_instances.append(new_instance)
return op_instances
@lru_cache(None)
def gen_ops_library() -> List[CKBatchedGemmOperation]:
"""
Parse the Universal Gemm instances defined in the composable kernel library folder.
"""
ck_library_dir = _ck_library_dir()
if not ck_library_dir:
return []
grep_result = subprocess.run(
[
"grep",
"-inR",
"DeviceBatchedGemmMultiD_Xdl_CShuffle_V3",
_ck_library_dir(),
],
capture_output=True,
text=True,
)
op_instances = parse_instances(grep_result.stdout.strip().split("\n"))
log.debug("ck instances from library: %d", len(op_instances))
schedulers = [
"BlockGemmPipelineScheduler::Intrawave",
"BlockGemmPipelineScheduler::Interwave",
]
gemm_specs = [
"GemmSpecialization::Default",
"GemmSpecialization::MPadding",
"GemmSpecialization::NPadding",
"GemmSpecialization::KPadding",
"GemmSpecialization::MNPadding",
"GemmSpecialization::MKPadding",
"GemmSpecialization::NKPadding",
"GemmSpecialization::MNKPadding",
]
# substitute templated args by looping through their domains
substitute_instances = []
for instance in op_instances:
sub_scheduler = instance.block_gemm_pipeline_scheduler == "BlkGemmPipeSched"
sub_spec = instance.gemm_specialization == "GemmSpec"
schedulers_range = (
schedulers if sub_scheduler else [instance.block_gemm_pipeline_scheduler]
)
spec_range = gemm_specs if sub_spec else [instance.gemm_specialization]
for scheduler in schedulers_range:
for spec in spec_range:
substitute_instances.append(
replace(
instance,
block_gemm_pipeline_scheduler=scheduler,
gemm_specialization=spec,
)
)
return substitute_instances
if __name__ == "__main__":
print(gen_ops_library())
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
from dataclasses import asdict, dataclass
from typing import Optional, Tuple
@dataclass
class CKBatchedGemmOperation:
"""
A python dataclass storing the template parameters of a CK Universal Gemm template instance
"""
a_layout: str
b_layout: str
ds_layouts: Tuple[str] # addmm specific
c_layout: str
a_element_dtype: str
b_element_dtype: str
ds_element_dtypes: Tuple[str] # addmm specific
c_element_dtype: str
acc_dtype: str
c_shuffle_dtype: str
a_elementwise_op: str
b_elementwise_op: str
c_elementwise_op: str
gemm_specialization: str
block_size: int
m_per_block: int
n_per_block: int
k_per_block: int
a_k1: int
b_k1: int
m_per_xdl: int
n_per_xdl: int
m_xdl_per_wave: int
n_xdl_per_wave: int
a_block_transfer_thread_cluster_lengths_ak0_m_ak1: Tuple[int, int, int]
a_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
a_block_transfer_src_access_order: Tuple[int, int, int]
a_block_transfer_src_vector_dim: int
a_block_transfer_src_scalar_per_vector: int
a_block_transfer_dst_scalar_per_vector_ak1: int
a_block_lds_extra_m: bool
b_block_transfer_thread_cluster_lengths_bk0_n_bk1: Tuple[int, int, int]
b_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
b_block_transfer_src_access_order: Tuple[int, int, int]
b_block_transfer_src_vector_dim: int
b_block_transfer_src_scalar_per_vector: int
b_block_transfer_dst_scalar_per_vector_bk1: int
b_block_lds_extra_n: bool
c_shuffle_m_xdl_per_wave_per_shuffle: int
c_shuffle_n_xdl_per_wave_per_shuffle: int
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block: (
Tuple[int, int, int, int]
)
c_shuffle_block_transfer_scalar_per_vector_n_per_block: Tuple[int]
block_gemm_pipeline_scheduler: str
block_gemm_pipeline_version: str
a_compute_dtype: Optional[str] = None
b_compute_dtype: Optional[str] = None
def name(self):
# cpp alias for template instance
return f"ck_device_batched_gemm_multi_d_xdl_c_shuffle_v3_{self.key_name()}"
def key_name(self):
# TBD; must be unique per instance. Intended to use as dict key
return "_".join(
[
"K"
+ field_name.replace("_", "").lower()
+ "V"
+ (
"x".join(map(str, iter(field_value)))
if isinstance(field_value, tuple)
else str(field_value).replace(":", "")
)
for field_name, field_value in self.dict_items()
]
)
def dict_items(self):
return asdict(self).items()
......@@ -130,9 +130,7 @@ def gen_conv_ops_library() -> List[CKGroupedConvFwdOp]:
# substitute templated args by looping through their domains
substitute_instances = []
for instance in op_instances:
sub_scheduler = (
instance.block_gemm_pipeline_scheduler == "BlkGemmPipeSched"
)
sub_scheduler = instance.block_gemm_pipeline_scheduler == "BlkGemmPipeSched"
sub_spec = instance.conv_forward_specialization == "ConvSpec"
schedulers_range = (
schedulers if sub_scheduler else [instance.block_gemm_pipeline_scheduler]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment