Commit 3852d58b authored by wangziyang
Browse files

update cp_async & init inject_ds_read

parent 19cdf0ca
......@@ -29,24 +29,28 @@ class GemmMMA(GemmBase):
chunk=self.chunk,
)
if self.is_gemm_ss():
print("gemm_ss")
return {
self.A: make_swizzled_layout(self.A),
self.B: make_swizzled_layout(self.B),
self.C: mma_emitter.make_mma_store_layout(self.C),
}
elif self.is_gemm_sr():
print("gemm_ss")
return {
self.A: make_swizzled_layout(self.A),
self.B: mma_emitter.make_mma_load_layout(self.B, matrix="B"),
self.C: mma_emitter.make_mma_store_layout(self.C),
}
elif self.is_gemm_rs():
print("gemm_ss")
return {
self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"),
self.B: make_swizzled_layout(self.B),
self.C: mma_emitter.make_mma_store_layout(self.C),
}
elif self.is_gemm_rr():
print("gemm_ss")
return {
self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"),
self.B: mma_emitter.make_mma_load_layout(self.B, matrix="B"),
......
......@@ -338,9 +338,28 @@ def InjectPTXAsyncCopy():
fpass : tvm.transform.Pass
The result pass
"""
print("Injecting PTX async copy for global to shared memory copy on DCU.")
return _ffi_api.InjectPTXAsyncCopy() # type: ignore
def InjectDSRead():
    """Build the pass that injects ds_read hardware instructions on DCU.

    Shared-memory-to-register loads (``BufferLoad`` from shared memory) are
    replaced with the ``ds_read_b64`` or ``ds_read_m32x16_b16`` hardware
    instructions for AMD DCU (gfx936, gfx942, etc.):

    - ``ds_read_b64`` loads 8 bytes (4 halfs or 2 floats) at once.
    - ``ds_read_m32x16_b16`` loads 32 bytes (16 halfs) at once.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    # The message is emitted before the FFI call, every time the pass is
    # constructed.
    print("Injecting ds_read for shared to register memory copy on DCU.")
    ds_read_pass = _ffi_api.InjectDSRead()  # type: ignore
    return ds_read_pass
def LowerDeviceStorageAccessInfo():
"""Lower attached storage access information on device.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment