Commit 550a1e5e authored by zhuwenwen
Browse files

Update int8_utils.py

parent d8b9028d
...@@ -446,10 +446,10 @@ def w8a8_block_int8_matmul( ...@@ -446,10 +446,10 @@ def w8a8_block_int8_matmul(
C_shape = A.shape[:-1] + (N, ) C_shape = A.shape[:-1] + (N, )
C = A.new_empty(C_shape, dtype=output_dtype) C = A.new_empty(C_shape, dtype=output_dtype)
if len(W8A8_TRITONJSON.triton_json_list)==0: if len(W8A8_TRITONJSON.triton_json_dict)==0:
config=None config=None
elif f"1_{N}_{K}_block[{block_n},{block_k}]" in W8A8_TRITONJSON.triton_json_list[0]: elif f"1_{N}_{K}_block[{block_n},{block_k}]" in W8A8_TRITONJSON.triton_json_dict:
if M<=16: if M<=16:
m_=M m_=M
elif M<=64: elif M<=64:
...@@ -472,7 +472,7 @@ def w8a8_block_int8_matmul( ...@@ -472,7 +472,7 @@ def w8a8_block_int8_matmul(
else: else:
m_=8192 m_=8192
config=W8A8_TRITONJSON.triton_json_list[0][f"{m_}_{N}_{K}_block[{block_n},{block_k}]"] config=W8A8_TRITONJSON.triton_json_dict[f"{m_}_{N}_{K}_block[{block_n},{block_k}]"]
else: else:
config=None config=None
...@@ -617,4 +617,4 @@ def block_dequant( ...@@ -617,4 +617,4 @@ def block_dequant(
i * block_k : min((i + 1) * block_k, k), i * block_k : min((i + 1) * block_k, k),
] *= x_s[j][i] ] *= x_s[j][i]
return x_dq_block return x_dq_block
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment