"vscode:/vscode.git/clone" did not exist on "c90d1d8c895a3c8f9adac585368a8c96bbbaeb69"
Unverified Commit 886d14cc authored by Max Podkorytov's avatar Max Podkorytov Committed by GitHub
Browse files

modify python wrapper for addmm (#1441)

parent 6fc7bff5
......@@ -62,17 +62,13 @@ def parse_instances(str_instances: List[str]) -> List[CKGemmOperation]:
i_current = i_next + 1
if i_next == -1:
break
# pad with `None`s for the fields which are not defined in the instance
template_args.insert(2, tuple()) # ds layout
template_args.insert(6, tuple()) # ds dtype
new_instance = CKGemmOperation(
*template_args, # type: ignore[arg-type]
*((None,) * (len(fields(CKGemmOperation)) - len(template_args))),
)
# the last 2 template parameters are optional
# if they are absent, substitute them with default values from Universal Gemm C++ template declaration
if new_instance.a_compute_dtype is None:
new_instance.a_compute_dtype = new_instance.c_element_dtype
if new_instance.b_compute_dtype is None:
new_instance.b_compute_dtype = new_instance.c_element_dtype
op_instances.append(new_instance)
return op_instances
......@@ -208,6 +204,8 @@ def gen_ops_preselected() -> List[CKGemmOperation]:
a_layout="Row",
b_layout="Col",
c_layout="Row",
ds_element_dtypes=tuple(),
ds_layouts=tuple(),
a_element_dtype="F16",
b_element_dtype="F16",
c_element_dtype="F16",
......
......@@ -10,10 +10,12 @@ class CKGemmOperation:
a_layout: str
b_layout: str
ds_layouts: Tuple[str] # addmm specific
c_layout: str
a_element_dtype: str
b_element_dtype: str
ds_element_dtypes: Tuple[str] # addmm specific
c_element_dtype: str
acc_dtype: str
......@@ -64,16 +66,15 @@ class CKGemmOperation:
Tuple[int, int, int, int]
)
c_shuffle_block_transfer_scalar_per_vector_n_per_block: int
block_gemm_pipeline_scheduler: str
block_gemm_pipeline_version: Optional[str]
block_gemm_pipeline_version: str
a_compute_dtype: Optional[str]
b_compute_dtype: Optional[str]
a_compute_dtype: Optional[str] = None
b_compute_dtype: Optional[str] = None
def name(self):
# cpp alias for template instance
return f"ck_devicegemm_xdl_shuffle_v3_{self.key_name()}"
return f"ck_devicegemm_multid_xdl_shuffle_v3_{self.key_name()}"
def key_name(self):
# TBD; must be unique per instance. Intended to use as dict key
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment