Unverified commit b10d49b2, authored by Jiaxing Ding and committed by GitHub
Browse files

[AMD] enable amd ci test & fix bug & fix dockerfile (#1244)

parent 468b1b70
......@@ -379,7 +379,7 @@ jobs:
pytest --verbose --color=yes --durations=0 --showlocals --cache-clear
)
"${PYTEST[@]}" --maxfail=3 --numprocesses=4 \
./python/amd/test_tilelang_test_amd.py
./python/amd
# Apple Metal tests
- name: Run Metal tests with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }})
......
......@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
&& cd TileLang && ./install_cuda.sh
&& cd TileLang && USE_CUDA=1 pip install -e . -v
CMD bash
......@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
&& cd TileLang && ./install_cuda.sh
&& cd TileLang && USE_CUDA=1 pip install -e . -v
CMD bash
......@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
&& cd TileLang && ./install_cuda.sh
&& cd TileLang && USE_CUDA=1 pip install -e . -v
CMD bash
......@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
&& cd TileLang && ./install_cuda.sh
&& cd TileLang && USE_CUDA=1 pip install -e . -v
CMD bash
......@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
&& cd TileLang && ./install_cuda.sh
&& cd TileLang && USE_CUDA=1 pip install -e . -v
CMD bash
......@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
&& cd TileLang && ./install_cuda.sh
&& cd TileLang && USE_CUDA=1 pip install -e . -v
CMD bash
......@@ -23,6 +23,6 @@ RUN conda install pip cmake && conda install -c conda-forge libstdcxx-ng=12 && c
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
&& cd TileLang && ./install_cuda.sh
&& cd TileLang && USE_CUDA=1 pip install -e . -v
CMD bash
......@@ -26,6 +26,6 @@ RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev z
RUN pip install cython
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main TileLang \
&& cd TileLang && cmake -S . -B build -DUSE_CUDA=ON && cmake --build build -j
&& cd TileLang && USE_CUDA=1 pip install -e . -v
CMD bash
......@@ -22,7 +22,7 @@ RUN conda run -n py_3.10 conda install pip cmake -y && \
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main tilelang && \
conda run -n py_3.10 bash -c "cd tilelang && ./install_rocm.sh"
conda run -n py_3.10 bash -c "cd tilelang && USE_ROCM=1 pip install -e . -v"
RUN conda init bash
......
......@@ -22,15 +22,6 @@ def tl_matmul(
b_transposed=True,
k_pack=1,
):
assert in_dtype in [
"float16",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
......@@ -190,6 +181,9 @@ def assert_tl_matmul_correctness(M,
if in_dtype == "int8":
A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8)
B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8)
elif in_dtype == "float8_e4m3fnuz":
A = torch.rand(A_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
B = torch.rand(B_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
else:
A = torch.rand(A_shape, device="cuda", dtype=getattr(torch, in_dtype))
B = torch.rand(B_shape, device="cuda", dtype=getattr(torch, in_dtype))
......
......@@ -217,6 +217,9 @@ def assert_tl_matmul_correctness(M,
if in_dtype == "int8":
A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8)
B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8)
elif in_dtype == "float8_e4m3fnuz":
A = torch.rand(A_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
B = torch.rand(B_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
else:
A = torch.rand(A_shape, device="cuda", dtype=getattr(torch, in_dtype))
B = torch.rand(B_shape, device="cuda", dtype=getattr(torch, in_dtype))
......@@ -264,11 +267,11 @@ def assert_tl_matmul_correctness(M,
@tilelang.testing.requires_rocm
def test_assert_tl_matmul():
assert_tl_matmul_correctness(
256, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
assert_tl_matmul_correctness(
256, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
assert_tl_matmul_correctness(
256, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
256, 256, 512, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
assert_tl_matmul_correctness(
256, 256, 512, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
......@@ -283,6 +286,21 @@ def test_assert_tl_matmul():
k_pack=2,
b_preshuffle=True)
assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", b_preshuffle=True)
assert_tl_matmul_correctness(
256, 256, 512, "float8_e4m3fnuz", "float32", b_transposed=False, b_preshuffle=True)
assert_tl_matmul_correctness(
256, 256, 512, "float8_e4m3fnuz", "float32", k_pack=2, b_preshuffle=True)
assert_tl_matmul_correctness(
256,
256,
512,
"float8_e4m3fnuz",
"float32",
k_pack=2,
b_transposed=False,
b_preshuffle=True)
if __name__ == "__main__":
tilelang.testing.main()
......@@ -374,8 +374,6 @@ class MatrixCoreIntrinEmitter:
a_local_stride: PrimExpr = k_inner * warp_rows * local_size_a if a_is_fragment else 0
b_local_stride: PrimExpr = k_inner * warp_cols * local_size_b if b_is_fragment else 0
print(a_local_stride, b_local_stride)
@T.macro
def _warp_mfma(A_local_buf, B_local_buf, C_local_buf):
for kp, i, j in T.grid(k_pack, warp_rows, warp_cols):
......@@ -678,34 +676,27 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
is_m_first: bool | None = False,
a_preshuffle: bool | None = False,
b_preshuffle: bool | None = False,
thread_var: Var | None = None,
):
self.a_dtype = a_dtype
self.b_dtype = b_dtype
self.accum_dtype = accum_dtype
self.a_transposed = a_transposed
self.b_transposed = b_transposed
# Hint Information
self.block_row_warps = block_row_warps
self.block_col_warps = block_col_warps
self.warp_row_tiles = warp_row_tiles
self.warp_col_tiles = warp_col_tiles
self.chunk = chunk
self._initialize_k_dim(a_dtype)
self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE)
self._initialize_mfma_prefix(self.k_dim)
self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim)
self._initialize_k_pack(k_pack)
self._initialize_is_m_first(is_m_first)
super().__init__(
a_dtype=a_dtype,
b_dtype=b_dtype,
accum_dtype=accum_dtype,
a_transposed=a_transposed,
b_transposed=b_transposed,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
reduce_k=reduce_k,
num_elems_per_byte=num_elems_per_byte,
k_pack=k_pack,
is_m_first=is_m_first,
thread_var=thread_var,
)
self._initialize_preshuffle(a_preshuffle, b_preshuffle)
self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y
self.reduce_k = reduce_k
self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k)
self.num_elems_per_byte = num_elems_per_byte
def _initialize_preshuffle(self, a_preshuffle: bool, b_preshuffle: bool):
if a_preshuffle is not None:
self.a_preshuffle = a_preshuffle
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment