Commit b89a5049 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents be58e518 886d14cc
...@@ -46,8 +46,10 @@ if(GPU_TARGETS MATCHES "gfx9") ...@@ -46,8 +46,10 @@ if(GPU_TARGETS MATCHES "gfx9")
list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp)
endif() endif()
list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp) list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp) if(GPU_TARGETS MATCHES "gfx94")
list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp) list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp)
endif()
list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp) list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp)
list(APPEND PROFILER_SOURCES profile_batched_gemm_reduce.cpp) list(APPEND PROFILER_SOURCES profile_batched_gemm_reduce.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp) list(APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp)
...@@ -128,8 +130,10 @@ if(GPU_TARGETS MATCHES "gfx9") ...@@ -128,8 +130,10 @@ if(GPU_TARGETS MATCHES "gfx9")
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance) if(GPU_TARGETS MATCHES "gfx94")
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance)
endif()
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_reduce_instance)
......
...@@ -62,17 +62,13 @@ def parse_instances(str_instances: List[str]) -> List[CKGemmOperation]: ...@@ -62,17 +62,13 @@ def parse_instances(str_instances: List[str]) -> List[CKGemmOperation]:
i_current = i_next + 1 i_current = i_next + 1
if i_next == -1: if i_next == -1:
break break
# pad with `None`s for the fields which are not defined in the instance
template_args.insert(2, tuple()) # ds layout
template_args.insert(6, tuple()) # ds dtype
new_instance = CKGemmOperation( new_instance = CKGemmOperation(
*template_args, # type: ignore[arg-type] *template_args, # type: ignore[arg-type]
*((None,) * (len(fields(CKGemmOperation)) - len(template_args))),
) )
# the last 2 template parameters are optional
# if they are absent, substitute them with default values from Universal Gemm C++ template declaration
if new_instance.a_compute_dtype is None:
new_instance.a_compute_dtype = new_instance.c_element_dtype
if new_instance.b_compute_dtype is None:
new_instance.b_compute_dtype = new_instance.c_element_dtype
op_instances.append(new_instance) op_instances.append(new_instance)
return op_instances return op_instances
...@@ -208,6 +204,8 @@ def gen_ops_preselected() -> List[CKGemmOperation]: ...@@ -208,6 +204,8 @@ def gen_ops_preselected() -> List[CKGemmOperation]:
a_layout="Row", a_layout="Row",
b_layout="Col", b_layout="Col",
c_layout="Row", c_layout="Row",
ds_element_dtypes=tuple(),
ds_layouts=tuple(),
a_element_dtype="F16", a_element_dtype="F16",
b_element_dtype="F16", b_element_dtype="F16",
c_element_dtype="F16", c_element_dtype="F16",
......
...@@ -10,10 +10,12 @@ class CKGemmOperation: ...@@ -10,10 +10,12 @@ class CKGemmOperation:
a_layout: str a_layout: str
b_layout: str b_layout: str
ds_layouts: Tuple[str] # addmm specific
c_layout: str c_layout: str
a_element_dtype: str a_element_dtype: str
b_element_dtype: str b_element_dtype: str
ds_element_dtypes: Tuple[str] # addmm specific
c_element_dtype: str c_element_dtype: str
acc_dtype: str acc_dtype: str
...@@ -64,16 +66,15 @@ class CKGemmOperation: ...@@ -64,16 +66,15 @@ class CKGemmOperation:
Tuple[int, int, int, int] Tuple[int, int, int, int]
) )
c_shuffle_block_transfer_scalar_per_vector_n_per_block: int c_shuffle_block_transfer_scalar_per_vector_n_per_block: int
block_gemm_pipeline_scheduler: str block_gemm_pipeline_scheduler: str
block_gemm_pipeline_version: Optional[str] block_gemm_pipeline_version: str
a_compute_dtype: Optional[str] a_compute_dtype: Optional[str] = None
b_compute_dtype: Optional[str] b_compute_dtype: Optional[str] = None
def name(self): def name(self):
# cpp alias for template instance # cpp alias for template instance
return f"ck_devicegemm_xdl_shuffle_v3_{self.key_name()}" return f"ck_devicegemm_multid_xdl_shuffle_v3_{self.key_name()}"
def key_name(self): def key_name(self):
# TBD; must be unique per instance. Intended to use as dict key # TBD; must be unique per instance. Intended to use as dict key
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment