Commit dbf9fd61 authored by lishen's avatar lishen
Browse files

Merge branch 'quant_main' into 'main'

量化scales传输size优化

See merge request dcutoolkit/deeplearing/DeepEP!20
parents d0fcf024 e57e9270
...@@ -135,8 +135,8 @@ struct LowLatencyLayout { ...@@ -135,8 +135,8 @@ struct LowLatencyLayout {
} }
LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts) { int num_ranks, int num_experts, int quant_group_size=0) {
const int num_scales = hidden / QUANTIZATION_GROUPSIZE; const int num_scales = quant_group_size == 0 ? 4 : hidden / QUANTIZATION_GROUPSIZE; // 应该是1,但是代码中为了满足int4对齐
// Dispatch and combine layout: // Dispatch and combine layout:
// - 2 symmetric odd/even send buffer // - 2 symmetric odd/even send buffer
...@@ -205,9 +205,9 @@ struct LowLatencyLayout { ...@@ -205,9 +205,9 @@ struct LowLatencyLayout {
}; };
inline size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, inline size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts) { int num_ranks, int num_experts, int quant_group_size=0) {
auto num_bytes = auto num_bytes =
LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts) LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size)
.total_bytes; .total_bytes;
return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) *
NUM_BUFFER_ALIGNMENT_BYTES; NUM_BUFFER_ALIGNMENT_BYTES;
......
...@@ -1271,10 +1271,10 @@ Buffer::internode_combine( ...@@ -1271,10 +1271,10 @@ Buffer::internode_combine(
#endif #endif
} }
void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) { void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts, int quant_group_size) {
EP_HOST_ASSERT(low_latency_mode); EP_HOST_ASSERT(low_latency_mode);
auto layout = LowLatencyLayout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); auto layout = LowLatencyLayout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size);
auto clean_meta_0 = layout.buffers[0].clean_meta(); auto clean_meta_0 = layout.buffers[0].clean_meta();
auto clean_meta_1 = layout.buffers[1].clean_meta(); auto clean_meta_1 = layout.buffers[1].clean_meta();
...@@ -1311,7 +1311,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1311,7 +1311,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
auto num_local_experts = num_experts / num_ranks; auto num_local_experts = num_experts / num_ranks;
// Buffer control // Buffer control
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size);
EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
auto buffer = layout.buffers[low_latency_buffer_idx]; auto buffer = layout.buffers[low_latency_buffer_idx];
auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];
......
...@@ -172,7 +172,7 @@ public: ...@@ -172,7 +172,7 @@ public:
std::optional<EventHandle> &previous_event, bool async, bool allocate_on_comm_stream); std::optional<EventHandle> &previous_event, bool async, bool allocate_on_comm_stream);
void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden,
int num_experts); int num_experts, int quant_group_size=0);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
......
...@@ -210,13 +210,13 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void ...@@ -210,13 +210,13 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
// Message package: hidden data, FP8 scales, index at source // Message package: hidden data, FP8 scales, index at source
// NOTES: currently we have 3 reserved int fields for future use // NOTES: currently we have 3 reserved int fields for future use
using vec_t = typename std::conditional<kUseQuant8Bit, int2, int4>::type; using vec_t = typename std::conditional<kUseQuant8Bit, int2, int4>::type;
constexpr size_t num_bytes_per_msg = sizeof(int4) + (kUseQuant8Bit ? (kHidden + kNumScales * sizeof(float)) : (kHidden * sizeof(hip_bfloat16))); constexpr size_t num_bytes_per_msg = sizeof(int4) +
(kUseQuant8Bit ? (kHidden + (kQuantGroupSize == 0 ? 4 : kNumScales) * sizeof(float)) : (kHidden * sizeof(hip_bfloat16)));
EP_STATIC_ASSERT(num_bytes_per_msg % sizeof(int4) == 0, "Invalid message size"); EP_STATIC_ASSERT(num_bytes_per_msg % sizeof(int4) == 0, "Invalid message size");
constexpr size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4); constexpr size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
// Expert counts // Expert counts
constexpr int kNumMaxWarpGroups = 1024 / kWarpSize; __shared__ int shared_num_tokens_sent_per_expert[kMaxNumWarps];
__shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];
// Sending phase // Sending phase
if ((phases & LOW_LATENCY_SEND_PHASE) == 0) if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
...@@ -230,7 +230,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void ...@@ -230,7 +230,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
constexpr int kNumThreadPerGroup = QUANTIZATION_GROUPSIZE / kNumElemsPerRead; constexpr int kNumThreadPerGroup = QUANTIZATION_GROUPSIZE / kNumElemsPerRead;
// EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0); // EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize % kNumPerChannels == 0, "Invalid vectorization"); EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize % kNumPerChannels == 0, "Invalid vectorization");
const auto num_threads = (num_warps - 1) * kWarpSize; const auto num_threads = num_warps * kWarpSize;
constexpr int hidden_bf16_int4 = kHidden / kNumElemsPerRead; constexpr int hidden_bf16_int4 = kHidden / kNumElemsPerRead;
for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) { for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
...@@ -375,7 +375,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void ...@@ -375,7 +375,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG); atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
} }
// This SM should be responsible for some destination experts, read `topk_idx` for them // This SM should be responsible for some destination experts, read `topk_idx` for them
int expert_count[kNumMaxWarpGroups] = {0}; int expert_count[kMaxNumWarps] = {0};
const auto expert_begin_idx = sm_id * num_warp_groups; const auto expert_begin_idx = sm_id * num_warp_groups;
const auto expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts); const auto expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts);
...@@ -465,7 +465,7 @@ LOW_LATENCY_DISPATCH_RECV: ...@@ -465,7 +465,7 @@ LOW_LATENCY_DISPATCH_RECV:
(kQuantGroupSize == 0 ? 1 : num_aligned_scales); (kQuantGroupSize == 0 ? 1 : num_aligned_scales);
// Shared between sub-warps in warp groups // Shared between sub-warps in warp groups
__shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups]; __shared__ int shared_num_recv_tokens[kMaxNumWarps], shared_recv_token_begin_idx[kMaxNumWarps];
// Wait tokens to arrive // Wait tokens to arrive
// NOTES: using sub-warp 1 to overlap with sub-warp 0 // NOTES: using sub-warp 1 to overlap with sub-warp 0
......
...@@ -212,7 +212,7 @@ class Buffer: ...@@ -212,7 +212,7 @@ class Buffer:
@staticmethod @staticmethod
def get_low_latency_rdma_size_hint( def get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank: int, hidden: int, num_ranks: int, num_experts: int num_max_dispatch_tokens_per_rank: int, hidden: int, num_ranks: int, num_experts: int, quant_group_size: int = 0
) -> int: ) -> int:
""" """
Get a minimum size requirement for the RDMA buffer. The size calculation will be done with BF16. Get a minimum size requirement for the RDMA buffer. The size calculation will be done with BF16.
...@@ -222,12 +222,13 @@ class Buffer: ...@@ -222,12 +222,13 @@ class Buffer:
hidden: the hidden dimension of each token. hidden: the hidden dimension of each token.
num_ranks: the number of EP group ranks. num_ranks: the number of EP group ranks.
num_experts: the number of all experts. num_experts: the number of all experts.
quant_group_size: the group size if use quant.
Returns: Returns:
size: the RDMA buffer size recommended. size: the RDMA buffer size recommended.
""" """
return deep_ep_cpp.get_low_latency_rdma_size_hint( return deep_ep_cpp.get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size
) )
def get_comm_stream(self) -> torch.Stream: def get_comm_stream(self) -> torch.Stream:
...@@ -823,7 +824,7 @@ class Buffer: ...@@ -823,7 +824,7 @@ class Buffer:
return combined_x, combined_topk_weights, EventOverlap(event) return combined_x, combined_topk_weights, EventOverlap(event)
def clean_low_latency_buffer( def clean_low_latency_buffer(
self, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int self, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int, quant_group_size: int = 0
) -> None: ) -> None:
""" """
As low-latency kernels require part of the buffer to be zero-initialized, so it is vital to clean the buffer As low-latency kernels require part of the buffer to be zero-initialized, so it is vital to clean the buffer
...@@ -835,8 +836,9 @@ class Buffer: ...@@ -835,8 +836,9 @@ class Buffer:
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value. num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
hidden: the hidden dimension of each token. hidden: the hidden dimension of each token.
num_experts: the number of all experts. num_experts: the number of all experts.
quant_group_size: the group size if use quant.
""" """
self.runtime.clean_low_latency_buffer(num_max_dispatch_tokens_per_rank, hidden, num_experts) self.runtime.clean_low_latency_buffer(num_max_dispatch_tokens_per_rank, hidden, num_experts, quant_group_size)
# noinspection PyTypeChecker # noinspection PyTypeChecker
def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor, def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
......
...@@ -6,7 +6,7 @@ import torch.distributed as dist ...@@ -6,7 +6,7 @@ import torch.distributed as dist
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
import deep_ep import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, create_grouped_scores, inplace_unique, per_token_cast_to_fp8, per_token_cast_back, hash_tensor from utils import init_dist, bench, bench_kineto, calc_diff, create_grouped_scores, inplace_unique, per_token_cast_to_fp8, per_token_cast_pg_back, hash_tensor
# Test compatibility with low latency functions # Test compatibility with low latency functions
import test_low_latency import test_low_latency
...@@ -127,7 +127,7 @@ def test_main(args: argparse.Namespace, num_sms: int, ...@@ -127,7 +127,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
hash_value += hash_tensor(recv_x[0]) hash_value += hash_tensor(recv_x[0])
hash_value += hash_tensor(recv_x[1]) hash_value += hash_tensor(recv_x[1])
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x recv_x = per_token_cast_pg_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
# Checks # Checks
recv_gbl_rank_prefix_sum = handle[-4] recv_gbl_rank_prefix_sum = handle[-4]
...@@ -153,7 +153,7 @@ def test_main(args: argparse.Namespace, num_sms: int, ...@@ -153,7 +153,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
dispatch_args.update({'previous_event': buffer.capture()}) dispatch_args.update({'previous_event': buffer.capture()})
recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args) recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else () event.current_stream_wait() if async_mode else ()
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x recv_x = per_token_cast_pg_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
if not is_rand: if not is_rand:
check_data(recv_x, recv_gbl_rank_prefix_sum) check_data(recv_x, recv_gbl_rank_prefix_sum)
......
...@@ -5,7 +5,7 @@ import torch.distributed as dist ...@@ -5,7 +5,7 @@ import torch.distributed as dist
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
import deep_ep import deep_ep
from utils import init_dist, bench, calc_diff, inplace_unique, per_token_cast_to_fp8, per_token_cast_back from utils import init_dist, bench, calc_diff, inplace_unique, per_token_cast_to_fp8, per_token_cast_pg_back
# Test compatibility with low latency functions # Test compatibility with low latency functions
import test_low_latency import test_low_latency
...@@ -99,7 +99,7 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks ...@@ -99,7 +99,7 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
dispatch_args.update({'previous_event': buffer.capture()}) dispatch_args.update({'previous_event': buffer.capture()})
recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(**dispatch_args) recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else () event.current_stream_wait() if async_mode else ()
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x recv_x = per_token_cast_pg_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
# Checks # Checks
rank_prefix_matrix = handle[0] rank_prefix_matrix = handle[0]
...@@ -126,7 +126,7 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks ...@@ -126,7 +126,7 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
dispatch_args.update({'num_worst_tokens': num_worst_tokens}) dispatch_args.update({'num_worst_tokens': num_worst_tokens})
recv_worst_x, recv_worst_topk_idx, recv_worst_topk_weights, empty_list, _, event = buffer.dispatch(**dispatch_args) recv_worst_x, recv_worst_topk_idx, recv_worst_topk_weights, empty_list, _, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else () event.current_stream_wait() if async_mode else ()
recv_worst_x = per_token_cast_back(*recv_worst_x) if isinstance(recv_worst_x, tuple) else recv_worst_x recv_worst_x = per_token_cast_pg_back(*recv_worst_x) if isinstance(recv_worst_x, tuple) else recv_worst_x
assert len(empty_list) == 0 assert len(empty_list) == 0
assert num_worst_tokens == recv_worst_x.size(0) assert num_worst_tokens == recv_worst_x.size(0)
assert num_worst_tokens == recv_worst_topk_idx.size(0) assert num_worst_tokens == recv_worst_topk_idx.size(0)
...@@ -143,7 +143,7 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks ...@@ -143,7 +143,7 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
dispatch_args.update({'previous_event': buffer.capture()}) dispatch_args.update({'previous_event': buffer.capture()})
recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args) recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else () event.current_stream_wait() if async_mode else ()
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x recv_x = per_token_cast_pg_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
if current_x is not x_pure_rand: if current_x is not x_pure_rand:
check_data(recv_x, rank_prefix_matrix) check_data(recv_x, rank_prefix_matrix)
......
...@@ -4,7 +4,7 @@ import torch.distributed as dist ...@@ -4,7 +4,7 @@ import torch.distributed as dist
from functools import partial from functools import partial
import deep_ep import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_pg_back
def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
...@@ -44,7 +44,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -44,7 +44,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
# print('run {}/{}, dispatch_use_fp8={}'.format(i + 1, num_times, dispatch_use_fp8)) # print('run {}/{}, dispatch_use_fp8={}'.format(i + 1, num_times, dispatch_use_fp8))
# return # return
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \ simulated_gemm_x = per_token_cast_pg_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \
if dispatch_use_fp8 else packed_recv_x.clone() if dispatch_use_fp8 else packed_recv_x.clone()
# print(f"rank{rank}: packed_recv_x[0]\n{packed_recv_x[0].cpu()}\n") # print(f"rank{rank}: packed_recv_x[0]\n{packed_recv_x[0].cpu()}\n")
# print(f"rank{rank}: packed_recv_x[1]\n{packed_recv_x[1].cpu()}\n") # print(f"rank{rank}: packed_recv_x[1]\n{packed_recv_x[1].cpu()}\n")
...@@ -53,7 +53,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -53,7 +53,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group) dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
for i in range(num_local_experts if do_check else 0): for i in range(num_local_experts if do_check else 0):
expert_id = rank * num_local_experts + i expert_id = rank * num_local_experts + i
recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i] recv_x = per_token_cast_pg_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i]
recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i] recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]
# Check expert indices # Check expert indices
......
...@@ -72,16 +72,16 @@ def per_token_cast_pg_back(x: torch.Tensor, x_scales: torch.Tensor): ...@@ -72,16 +72,16 @@ def per_token_cast_pg_back(x: torch.Tensor, x_scales: torch.Tensor):
x_scales = x_scales.view(x.size(0), -1, 1) x_scales = x_scales.view(x.size(0), -1, 1)
return (x_fp32_padded * x_scales).view(x_padded.shape).to(torch.bfloat16)[:,:n].contiguous() return (x_fp32_padded * x_scales).view(x_padded.shape).to(torch.bfloat16)[:,:n].contiguous()
def per_token_cast_pc_back(x_int8: torch.Tensor, x_scales: torch.Tensor): def per_token_cast_pc_back(x: torch.Tensor, x_scales: torch.Tensor):
if x_int8.numel() == 0: if x.numel() == 0:
return x_int8.to(torch.bfloat16) return x.to(torch.bfloat16)
assert x_int8.dim() == 2 assert x.dim() == 2
m, n = x_int8.shape m, n = x.shape
aligned_n = align_up(n, 128) aligned_n = align_up(n, 128)
x_int8_padded = torch.nn.functional.pad(x_int8, (0, aligned_n - n), mode='constant', value=0) x_padded = torch.nn.functional.pad(x, (0, aligned_n - n), mode='constant', value=0)
x_fp32_padded = x_int8_padded.to(torch.float32).view(m, -1, 1) x_fp32_padded = x_padded.to(torch.float32).view(m, -1, 1)
x_scales = x_scales.view(m, -1, 1).to(torch.float32) x_scales = x_scales.view(m, -1, 1).to(torch.float32)
x_deq = (x_fp32_padded * x_scales).view(m, aligned_n) x_deq = (x_fp32_padded * x_scales).view(m, aligned_n)
return x_deq[:, :n].to(torch.bfloat16).contiguous() return x_deq[:, :n].to(torch.bfloat16).contiguous()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment