Commit 0e886dab authored by wenjh

Merge develop_v2.4


Signed-off-by: wenjh <wenjh@sugon.com>
parents e56de127 b944277c
@@ -8,6 +8,7 @@ import triton
 import triton.language as tl
 import pandas as pd
 from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_helper
+from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
 import functools
 import logging
@@ -557,7 +558,7 @@ def apply_w8a8_block_int8_linear_batched_helper(m: int,
 k: int,
 out_dtype: Type[torch.dtype] = torch.float16,
 device: str = "cuda",
-block_size: List[int]=[128,128],
+block_size: List[int]=[blockwise_fp8_block_len,blockwise_fp8_block_len],
 bias: Optional[torch.Tensor] = None,
 best_config:Optional[dict]=None):
@@ -596,7 +597,7 @@ def apply_w8a8_block_int8_linear_helper(m: int,
 k: int,
 out_dtype: Type[torch.dtype] = torch.float16,
 device: str = "cuda",
-block_size: List[int]=[128,128],
+block_size: List[int]=[blockwise_fp8_block_len,blockwise_fp8_block_len],
 bias: Optional[torch.Tensor] = None,
 best_config:Optional[dict]=None):
@@ -771,7 +772,7 @@ def main():
 n_list=[7168]
 k_list=[1152]
-block_size=[128, 128]
+block_size=[blockwise_fp8_block_len, blockwise_fp8_block_len]
 out_dtype=torch.bfloat16
...
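The hunks above swap each hardcoded [128,128] default for the shared blockwise_fp8_block_len constant. A minimal sketch of that pattern, assuming blockwise_fp8_block_len is a plain int (128 is an assumed value here; the real constant is imported from transformer_engine.pytorch.fp8), with a hypothetical helper name:

# Sketch only, not the repository's code; `apply_helper` is hypothetical.
from typing import List, Optional

blockwise_fp8_block_len = 128  # assumed value; normally imported from transformer_engine.pytorch.fp8

def apply_helper(m: int, n: int, k: int,
                 block_size: Optional[List[int]] = None) -> List[int]:
    # Resolving the default at call time keeps the helper in sync with the
    # shared constant and sidesteps Python's mutable-default-argument pitfall
    # that a literal `block_size: List[int]=[128,128]` default carries.
    if block_size is None:
        block_size = [blockwise_fp8_block_len, blockwise_fp8_block_len]
    return block_size

Note that a list default such as [blockwise_fp8_block_len, blockwise_fp8_block_len] is evaluated once at function definition, so all calls share one list object; the None idiom above avoids that, though for benchmark helpers like these the shared default is harmless as long as no callee mutates it.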
@@ -8,6 +8,7 @@ import triton
 import triton.language as tl
 import pandas as pd
 from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_helper_b
+from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
 import functools
 import logging
@@ -596,7 +597,7 @@ def apply_w8a8_block_int8_linear_batched_helper(m: int,
 k: int,
 out_dtype: Type[torch.dtype] = torch.float16,
 device: str = "cuda",
-block_size: List[int]=[128,128],
+block_size: List[int]=[blockwise_fp8_block_len,blockwise_fp8_block_len],
 bias: Optional[torch.Tensor] = None,
 best_config:Optional[dict]=None):
@@ -639,7 +640,7 @@ def apply_w8a8_block_int8_linear_helper(m: int,
 k: int,
 out_dtype: Type[torch.dtype] = torch.float16,
 device: str = "cuda",
-block_size: List[int]=[128,128],
+block_size: List[int]=[blockwise_fp8_block_len,blockwise_fp8_block_len],
 bias: Optional[torch.Tensor] = None,
 best_config:Optional[dict]=None):
@@ -821,7 +822,7 @@ def main():
 n_list=[7168]
 k_list=[1152]
-block_size=[128, 128]
+block_size=[blockwise_fp8_block_len, blockwise_fp8_block_len]
 out_dtype=torch.bfloat16
...
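For context on what the apply_w8a8_block_int8_linear_* helpers measure, here is a dequantize-then-matmul reference sketch of a w8a8 blockwise-int8 linear. The function name, argument layout, and scale shapes are assumptions for illustration; the benchmarked path presumably runs a fused Triton/cutlass kernel that keeps the operands in int8 and applies the scales per accumulated block.

# Reference sketch (correctness oracle), not the fused kernel.
import torch

def w8a8_block_int8_linear_ref(a_q: torch.Tensor,   # (m, k) int8 activations
                               a_s: torch.Tensor,   # (m, k // blk_k) float scales
                               b_q: torch.Tensor,   # (n, k) int8 weights
                               b_s: torch.Tensor,   # (n // blk_n, k // blk_k) float scales
                               block_size=(128, 128),
                               out_dtype=torch.bfloat16) -> torch.Tensor:
    # Assumes n and k are divisible by blk_n and blk_k respectively.
    blk_n, blk_k = block_size
    # Broadcast the coarse per-group / per-block scales back to element
    # granularity and dequantize to float...
    a = a_q.float() * a_s.float().repeat_interleave(blk_k, dim=1)
    b = b_q.float() * b_s.float().repeat_interleave(blk_n, dim=0).repeat_interleave(blk_k, dim=1)
    # ...then run the matmul in float.
    return (a @ b.t()).to(out_dtype)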
@@ -9,6 +9,7 @@ import triton.language as tl
 import pandas as pd
 import logging
 import math
+from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
 def to_int8(tensor: torch.Tensor):
@@ -118,7 +119,7 @@ def _int8_gemm_helper(m: int,
 k: int,
 out_dtype: Type[torch.dtype] = torch.float16,
 device: str = "cuda",
-block_size: List[int]=[128,128],
+block_size: List[int]=[blockwise_fp8_block_len,blockwise_fp8_block_len],
 best_config:Optional[list] = None):
 # Test for a cutlass kernel with per-token activation quantization
 # and per-output channel weight quantization.
@@ -143,7 +144,7 @@ def _int8_gemm_helper_b(m: int,
 k: int,
 out_dtype: Type[torch.dtype] = torch.float16,
 device: str = "cuda",
-block_size: List[int]=[128,128],
+block_size: List[int]=[blockwise_fp8_block_len,blockwise_fp8_block_len],
 best_config:Optional[list] = None):
 # Test for a cutlass kernel with per-token activation quantization
 # and per-output channel weight quantization.
@@ -168,7 +169,7 @@ def _int8_gemm_helper_test(m: int,
 k: int,
 out_dtype: Type[torch.dtype] = torch.float16,
 device: str = "cuda",
-block_size: List[int]=[128,128],
+block_size: List[int]=[blockwise_fp8_block_len,blockwise_fp8_block_len],
 best_config:Optional[list] = None):
 # Test for a cutlass kernel with per-token activation quantization
 # and per-output channel weight quantization.
@@ -197,7 +198,7 @@ def main():
 m_list=[1,2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768]
 n_list=[576,2048,7168,256,7168,1536,1536]
 k_list=[7168,512,1024,7168,128,7168,1536]
-block_size=[128,128]
+block_size=[blockwise_fp8_block_len,blockwise_fp8_block_len]
 out_dtype=torch.bfloat16
 _n=[]
 _k=[]
...
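The per_token_group_quant module touched above produces the int8 operands the GEMM helpers consume. A minimal sketch of symmetric per-token-group int8 quantization with the group size tied to the block length (the function name and the 127 clipping choice are assumptions, not the library's kernel):

# Sketch only; the real quantizer is a Triton kernel in per_token_group_quant.
import torch

def per_token_group_quant_int8(x: torch.Tensor, group_size: int = 128):
    # Each contiguous run of `group_size` values along the last dim gets its
    # own symmetric scale; assumes x.shape[-1] is divisible by group_size.
    xg = x.reshape(-1, group_size).float()
    # Map the per-group absmax onto the int8 range [-127, 127].
    scale = xg.abs().amax(dim=-1, keepdim=True).clamp_(min=1e-8) / 127.0
    q = torch.round(xg / scale).clamp_(-128, 127).to(torch.int8)
    return q.reshape(x.shape), scale.reshape(*x.shape[:-1], -1)

Keeping group_size equal to the same blockwise_fp8_block_len constant the GEMM helpers use is presumably the point of centralizing it: quantizer and kernels then agree on the block granularity by construction.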