Commit b89a5049 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents be58e518 886d14cc
......@@ -46,8 +46,10 @@ if(GPU_TARGETS MATCHES "gfx9")
list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp)
endif()
list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp)
if(GPU_TARGETS MATCHES "gfx94")
list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp)
endif()
list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp)
list(APPEND PROFILER_SOURCES profile_batched_gemm_reduce.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp)
......@@ -128,8 +130,10 @@ if(GPU_TARGETS MATCHES "gfx9")
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
if(GPU_TARGETS MATCHES "gfx94")
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance)
endif()
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_reduce_instance)
......
......@@ -62,17 +62,13 @@ def parse_instances(str_instances: List[str]) -> List[CKGemmOperation]:
i_current = i_next + 1
if i_next == -1:
break
# pad with `None`s for the fields which are not defined in the instance
template_args.insert(2, tuple()) # ds layout
template_args.insert(6, tuple()) # ds dtype
new_instance = CKGemmOperation(
*template_args, # type: ignore[arg-type]
*((None,) * (len(fields(CKGemmOperation)) - len(template_args))),
)
# the last 2 template parameters are optional
# if they are absent, substitute them with default values from Universal Gemm C++ template declaration
if new_instance.a_compute_dtype is None:
new_instance.a_compute_dtype = new_instance.c_element_dtype
if new_instance.b_compute_dtype is None:
new_instance.b_compute_dtype = new_instance.c_element_dtype
op_instances.append(new_instance)
return op_instances
......@@ -208,6 +204,8 @@ def gen_ops_preselected() -> List[CKGemmOperation]:
a_layout="Row",
b_layout="Col",
c_layout="Row",
ds_element_dtypes=tuple(),
ds_layouts=tuple(),
a_element_dtype="F16",
b_element_dtype="F16",
c_element_dtype="F16",
......
......@@ -10,10 +10,12 @@ class CKGemmOperation:
a_layout: str
b_layout: str
ds_layouts: Tuple[str] # addmm specific
c_layout: str
a_element_dtype: str
b_element_dtype: str
ds_element_dtypes: Tuple[str] # addmm specific
c_element_dtype: str
acc_dtype: str
......@@ -64,16 +66,15 @@ class CKGemmOperation:
Tuple[int, int, int, int]
)
c_shuffle_block_transfer_scalar_per_vector_n_per_block: int
block_gemm_pipeline_scheduler: str
block_gemm_pipeline_version: Optional[str]
block_gemm_pipeline_version: str
a_compute_dtype: Optional[str]
b_compute_dtype: Optional[str]
a_compute_dtype: Optional[str] = None
b_compute_dtype: Optional[str] = None
def name(self):
# cpp alias for template instance
return f"ck_devicegemm_xdl_shuffle_v3_{self.key_name()}"
return f"ck_devicegemm_multid_xdl_shuffle_v3_{self.key_name()}"
def key_name(self):
# TBD; must be unique per instance. Intended to use as dict key
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment