Commit b944277c authored by wenjh's avatar wenjh
Browse files

[Blockwise] Add support for block_len=64



Add env to choose block_len of blockwise quantize.
Signed-off-by: wenjh <wenjh@sugon.com>

Fix pytest of blockwise error
Signed-off-by: wenjh <wenjh@sugon.com>

Resolve new API in int8 gemm test
Signed-off-by: wenjh <wenjh@sugon.com>

Fix incorrect launch param
Signed-off-by: wenjh <wenjh@sugon.com>

Fix 1D blockwise(64) acc error
Signed-off-by: wenjh <wenjh@sugon.com>
parent 251dcc7e
......@@ -8,6 +8,7 @@ import triton
import triton.language as tl
import pandas as pd
from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_helper
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
import functools
import logging
......@@ -557,7 +558,7 @@ def apply_w8a8_block_int8_linear_batched_helper(m: int,
k: int,
out_dtype: Type[torch.dtype] = torch.float16,
device: str = "cuda",
block_size: List[int]=[128,128],
block_size: List[int]=[blockwise_fp8_block_len,blockwise_fp8_block_len],
bias: Optional[torch.Tensor] = None,
best_config:Optional[dict]=None):
......@@ -596,7 +597,7 @@ def apply_w8a8_block_int8_linear_helper(m: int,
k: int,
out_dtype: Type[torch.dtype] = torch.float16,
device: str = "cuda",
block_size: List[int]=[128,128],
block_size: List[int]=[blockwise_fp8_block_len,blockwise_fp8_block_len],
bias: Optional[torch.Tensor] = None,
best_config:Optional[dict]=None):
......@@ -771,7 +772,7 @@ def main():
n_list=[7168]
k_list=[1152]
block_size=[128, 128]
block_size=[blockwise_fp8_block_len, blockwise_fp8_block_len]
out_dtype=torch.bfloat16
......
......@@ -8,6 +8,7 @@ import triton
import triton.language as tl
import pandas as pd
from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_helper_b
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
import functools
import logging
......@@ -596,7 +597,7 @@ def apply_w8a8_block_int8_linear_batched_helper(m: int,
k: int,
out_dtype: Type[torch.dtype] = torch.float16,
device: str = "cuda",
block_size: List[int]=[128,128],
block_size: List[int]=[blockwise_fp8_block_len,blockwise_fp8_block_len],
bias: Optional[torch.Tensor] = None,
best_config:Optional[dict]=None):
......@@ -639,7 +640,7 @@ def apply_w8a8_block_int8_linear_helper(m: int,
k: int,
out_dtype: Type[torch.dtype] = torch.float16,
device: str = "cuda",
block_size: List[int]=[128,128],
block_size: List[int]=[blockwise_fp8_block_len,blockwise_fp8_block_len],
bias: Optional[torch.Tensor] = None,
best_config:Optional[dict]=None):
......@@ -821,7 +822,7 @@ def main():
n_list=[7168]
k_list=[1152]
block_size=[128, 128]
block_size=[blockwise_fp8_block_len, blockwise_fp8_block_len]
out_dtype=torch.bfloat16
......
......@@ -9,6 +9,7 @@ import triton.language as tl
import pandas as pd
import logging
import math
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
def to_int8(tensor: torch.Tensor):
......@@ -118,7 +119,7 @@ def _int8_gemm_helper(m: int,
k: int,
out_dtype: Type[torch.dtype] = torch.float16,
device: str = "cuda",
block_size: List[int]=[128,128],
block_size: List[int]=[blockwise_fp8_block_len,blockwise_fp8_block_len],
best_config:Optional[list] = None):
# Test for a cutlass kernel with per-token activation quantization
# and per-output channel weight quantization.
......@@ -143,7 +144,7 @@ def _int8_gemm_helper_b(m: int,
k: int,
out_dtype: Type[torch.dtype] = torch.float16,
device: str = "cuda",
block_size: List[int]=[128,128],
block_size: List[int]=[blockwise_fp8_block_len,blockwise_fp8_block_len],
best_config:Optional[list] = None):
# Test for a cutlass kernel with per-token activation quantization
# and per-output channel weight quantization.
......@@ -168,7 +169,7 @@ def _int8_gemm_helper_test(m: int,
k: int,
out_dtype: Type[torch.dtype] = torch.float16,
device: str = "cuda",
block_size: List[int]=[128,128],
block_size: List[int]=[blockwise_fp8_block_len,blockwise_fp8_block_len],
best_config:Optional[list] = None):
# Test for a cutlass kernel with per-token activation quantization
# and per-output channel weight quantization.
......@@ -197,7 +198,7 @@ def main():
m_list=[1,2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768]
n_list=[576,2048,7168,256,7168,1536,1536]
k_list=[7168,512,1024,7168,128,7168,1536]
block_size=[128,128]
block_size=[blockwise_fp8_block_len,blockwise_fp8_block_len]
out_dtype=torch.bfloat16
_n=[]
_k=[]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment