Unverified Commit 667632cc authored by guchaoyang's avatar guchaoyang Committed by GitHub
Browse files

Merge branch 'main' into dcu

parents d6dd2ddf a874e4e8
import math import math
def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head, def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local, max_splits):
is_causal_or_local, max_splits):
""" """
Determines the optimal number of splits for maximizing GPU occupancy while balancing memory efficiency. Determines the optimal number of splits for maximizing GPU occupancy while balancing memory efficiency.
......
...@@ -128,9 +128,7 @@ def per_token_group_quant_fp8( ...@@ -128,9 +128,7 @@ def per_token_group_quant_fp8(
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization. scaling factor for quantization.
""" """
assert (x.shape[-1] % assert x.shape[-1] % group_size == 0, f"the last dimension of `x` {x.shape[-1]} must be divisible by `group_size` {group_size}"
group_size == 0), (f"the last dimension of `x` {x.shape[-1]} must be divisible "
f"by `group_size` {group_size}")
assert x.stride(-1) == 1, "`x` groups must be contiguous" assert x.stride(-1) == 1, "`x` groups must be contiguous"
finfo = torch.finfo(dtype) finfo = torch.finfo(dtype)
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment