Commit 99d12b98 authored by Xtra, committed by GitHub

fix bias epilogue (#141)

parent a5138ed3
@@ -4,6 +4,7 @@
// clang-format off
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/fusion/operations.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
@@ -60,6 +61,9 @@ struct Mxfp6Mxfp8GemmSm120 {
using ThreadBlockShape = Shape<_128,_128,_128>; // Threadblock's tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
// use per-column bias, i.e. every column has a different bias
using EVTOp = cutlass::epilogue::fusion::LinCombPerColBias<ElementD, ElementAccumulator>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ThreadBlockShape, ClusterShape,
@@ -67,7 +71,8 @@ struct Mxfp6Mxfp8GemmSm120 {
ElementAccumulator, ElementAccumulator,
ElementC, LayoutCTag, AlignmentC,
ElementD, LayoutDTag, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy
EVTOp
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
@@ -127,7 +132,7 @@ typename Mxfp6Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp6_mxfp8(
auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
if (bias){
auto stride_bias = cutlass::make_cute_packed_stride(Mxfp6Mxfp8GemmSm120::StrideC{}, {});
using StrideBias = Stride<cutlass::_0, cutlass::_1, int64_t>;
typename Mxfp6Mxfp8GemmSm120::Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
@@ -143,12 +148,16 @@ typename Mxfp6Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp6_mxfp8(
layout_SFB},
{ // Epilogue arguments
{}, // epilogue.thread
static_cast<Mxfp6Mxfp8GemmSm120::Gemm::ElementC const*>(bias->data_ptr()),
stride_bias,
static_cast<Mxfp6Mxfp8GemmSm120::Gemm::ElementC const*>(D.data_ptr()),
stride_D,
static_cast<Mxfp6Mxfp8GemmSm120::Gemm::ElementD*>(D.data_ptr()),
stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
static const float beta_zero = 0.0f;
fusion_args.beta_ptr = &beta_zero;
fusion_args.bias_ptr = static_cast<Mxfp6Mxfp8GemmSm120::Gemm::ElementC const*>(bias->data_ptr());
fusion_args.dBias = StrideBias{};
return arguments;
} else {
typename Mxfp6Mxfp8GemmSm120::Gemm::Arguments arguments{
@@ -171,6 +180,8 @@ typename Mxfp6Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp6_mxfp8(
stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
static const float beta_zero = 0.0f;
fusion_args.beta_ptr = &beta_zero;
return arguments;
}
}
......
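The same change is applied to all three SM120 kernels in this commit, so a condensed sketch of the pattern may help. The fix stops routing the bias through the C operand and instead uses the `LinCombPerColBias` fusion: beta is pinned to zero through a pointer so C no longer contributes, and the bias gets the stride `Stride<_0, _1, int64_t>`, i.e. stride 0 along M (every row reads the same value) and stride 1 along N (one value per column). A minimal sketch under stated assumptions: `set_bias_fusion_args` is a hypothetical helper, and the bfloat16 element types merely mirror the bf16 tests below, not the repository's exact aliases.

```cpp
// Sketch only, mirroring the diff above -- not the repository's exact code.
// Assumption: ElementC/ElementD are bfloat16, as in the bf16 tests below.
#include <cstdint>
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/fusion/operations.hpp"
#include "cute/layout.hpp"

using ElementC = cutlass::bfloat16_t;
using ElementD = cutlass::bfloat16_t;
using ElementAccumulator = float;

// D = alpha * acc + beta * C + bias, with one bias value per output column.
using EVTOp = cutlass::epilogue::fusion::LinCombPerColBias<ElementD, ElementAccumulator>;

// (_0, _1, int64_t): stride 0 over M (broadcast down rows), 1 over N, batch stride L.
using StrideBias = cute::Stride<cute::_0, cute::_1, int64_t>;

// Hypothetical helper name; the commit inlines these assignments instead.
template <class GemmArguments>
void set_bias_fusion_args(GemmArguments& arguments,
                          float const* alpha_ptr,      // device scalar
                          ElementC const* bias_ptr) {  // device vector of length n
  auto& fusion_args = arguments.epilogue.thread;
  fusion_args.alpha_ptr = alpha_ptr;
  static const float beta_zero = 0.0f;  // beta = 0: the C operand no longer contributes
  fusion_args.beta_ptr = &beta_zero;
  fusion_args.bias_ptr = bias_ptr;      // one ElementC value per output column
  fusion_args.dBias = StrideBias{};     // per-column broadcast layout
}
```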
@@ -4,6 +4,7 @@
// clang-format off
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/fusion/operations.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
@@ -60,6 +61,9 @@ struct Mxfp8GemmSm120 {
using ThreadBlockShape = Shape<_128,_128,_128>; // Threadblock's tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
// use per-column bias, i.e. every column has a different bias
using EVTOp = cutlass::epilogue::fusion::LinCombPerColBias<ElementD, ElementAccumulator>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ThreadBlockShape, ClusterShape,
@@ -67,7 +71,8 @@ struct Mxfp8GemmSm120 {
ElementAccumulator, ElementAccumulator,
ElementC, LayoutCTag, AlignmentC,
ElementD, LayoutDTag, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy
EVTOp
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
@@ -127,7 +132,7 @@ typename Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp8(
auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
if (bias){
auto stride_bias = cutlass::make_cute_packed_stride(Mxfp8GemmSm120::StrideC{}, {});
using StrideBias = Stride<cutlass::_0, cutlass::_1, int64_t>;
typename Mxfp8GemmSm120::Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
@@ -143,12 +148,16 @@ typename Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp8(
layout_SFB},
{ // Epilogue arguments
{}, // epilogue.thread
static_cast<Mxfp8GemmSm120::Gemm::ElementC const*>(bias->data_ptr()),
stride_bias,
static_cast<Mxfp8GemmSm120::Gemm::ElementC const*>(D.data_ptr()),
stride_D,
static_cast<Mxfp8GemmSm120::Gemm::ElementD*>(D.data_ptr()),
stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
static const float beta_zero = 0.0f;
fusion_args.beta_ptr = &beta_zero;
fusion_args.bias_ptr = static_cast<Mxfp8GemmSm120::Gemm::ElementC const*>(bias->data_ptr());
fusion_args.dBias = StrideBias{};
return arguments;
} else {
typename Mxfp8GemmSm120::Gemm::Arguments arguments{
@@ -171,6 +180,8 @@ typename Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp8(
stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
static const float beta_zero = 0.0f;
fusion_args.beta_ptr = &beta_zero;
return arguments;
}
}
......
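In formula form, what each kernel now computes when the bias path is taken is

```math
D_{ij} = \alpha \sum_{k} A_{ik} B_{kj} + \mathrm{bias}_j, \qquad \beta = 0
```

with a single bias value per output column j broadcast down all rows, rather than routing the bias through the C operand as the element-wise term beta * C_ij, which is what the pre-fix argument wiring did.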
@@ -4,6 +4,7 @@
// clang-format off
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/fusion/operations.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
@@ -60,6 +61,9 @@ struct Fp4GemmSm120 {
using ThreadBlockShape = Shape<_128,_128,_128>; // Threadblock's tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
// use per-column bias, i.e. every column has a different bias
using EVTOp = cutlass::epilogue::fusion::LinCombPerColBias<ElementD, ElementAccumulator>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ThreadBlockShape, ClusterShape,
@@ -67,7 +71,8 @@ struct Fp4GemmSm120 {
ElementAccumulator, ElementAccumulator,
ElementC, LayoutCTag, AlignmentC,
ElementD, LayoutDTag, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy
EVTOp
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
@@ -127,7 +132,7 @@ typename Fp4GemmSm120::Gemm::Arguments args_from_options_nvfp4_nvfp4(
auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
if (bias){
auto stride_bias = cutlass::make_cute_packed_stride(Fp4GemmSm120::StrideC{}, {});
using StrideBias = Stride<cutlass::_0, cutlass::_1, int64_t>;
typename Fp4GemmSm120::Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
@@ -143,12 +148,16 @@ typename Fp4GemmSm120::Gemm::Arguments args_from_options_nvfp4_nvfp4(
layout_SFB},
{ // Epilogue arguments
{}, // epilogue.thread
static_cast<Fp4GemmSm120::Gemm::ElementC const*>(bias->data_ptr()),
stride_bias,
static_cast<Fp4GemmSm120::Gemm::ElementC const*>(D.data_ptr()),
stride_D,
static_cast<Fp4GemmSm120::Gemm::ElementD*>(D.data_ptr()),
stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
static const float beta_zero = 0.0f;
fusion_args.beta_ptr = &beta_zero;
fusion_args.bias_ptr = static_cast<Fp4GemmSm120::Gemm::ElementC const*>(bias->data_ptr());
fusion_args.dBias = StrideBias{};
return arguments;
} else {
typename Fp4GemmSm120::Gemm::Arguments arguments{
@@ -171,6 +180,8 @@ typename Fp4GemmSm120::Gemm::Arguments args_from_options_nvfp4_nvfp4(
stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
static const float beta_zero = 0.0f;
fusion_args.beta_ptr = &beta_zero;
return arguments;
}
}
......
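For intuition about the `Stride<_0, _1, int64_t>` pattern used in all three files, a small stand-alone CuTe check (a sketch, assuming the CuTe headers bundled with CUTLASS; `bias_layout` is an illustrative name) shows that the offset such a layout computes depends only on the column and batch index, never the row -- exactly the per-column broadcast the comment in each struct describes.

```cpp
#include <cassert>
#include <cstdint>
#include "cute/layout.hpp"

int main() {
  using namespace cute;
  // Same stride pattern as StrideBias above:
  // 0 over rows (M), 1 over columns (N), dense batch stride over L.
  auto bias_layout = make_layout(make_shape(128, 64, 1),
                                 make_stride(_0{}, _1{}, int64_t{64}));
  // The row index is multiplied by stride 0, so it never changes the offset:
  assert(bias_layout(0, 5, 0) == bias_layout(127, 5, 0));
  assert(bias_layout(3, 7, 0) == 7);  // offset = 3*0 + 7*1 + 0*64
  return 0;
}
```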
import torch
import time
from test_bench import MMWeightMxfp8ActMxfp6


def test_speed(m, k, n):
    with torch.no_grad():
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
        weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
        bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
        mm = MMWeightMxfp8ActMxfp6(weight, bias)

        # warmup
        output_tensor = mm.apply(input_tensor)
        torch.cuda.synchronize()

        start_time = time.time()
        for i in range(100):
            output_tensor = mm.apply(input_tensor)
        torch.cuda.synchronize()
        end_time = time.time()
        lightx2v_kernel_time = (end_time - start_time) / 100
        print(f"lightx2v-kernel time: {lightx2v_kernel_time}")

        # reference path: use the same (m, k) x (k, n) problem shape as the kernel
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
        weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
        bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
        linear = torch.nn.Linear(k, n, bias=True).cuda()
        linear.weight.data = weight
        linear.bias.data = bias

        # warmup
        ref_output_tensor = linear(input_tensor)
        torch.cuda.synchronize()

        start_time = time.time()
        for i in range(100):
            ref_output_tensor = linear(input_tensor)
        torch.cuda.synchronize()
        end_time = time.time()
        ref_time = (end_time - start_time) / 100
        print(f"ref time: {ref_time}")
        print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}")


def test_accuracy(m, k, n):
    with torch.no_grad():
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
        weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
        bias = torch.ones(1, n, dtype=torch.bfloat16).cuda() * 50
        linear = torch.nn.Linear(k, n, bias=True).cuda()
        linear.weight.data = weight
        linear.bias.data = bias
        ref_output_tensor = linear(input_tensor)

        mm = MMWeightMxfp8ActMxfp6(weight, bias)
        output_tensor = mm.apply(input_tensor)
        # print(f"ref_output_tensor: {ref_output_tensor}")
        # print(f"output_tensor: {output_tensor}")

        # cosine similarity between the flattened outputs
        cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0)
        print(f"cos : {cos}")


if __name__ == "__main__":
    test_sizes = [
        (32130, 5120, 5120),
        (512, 5120, 5120),
        (257, 5120, 5120),
        (32130, 5120, 13824),
        (32130, 13824, 5120),
        (75348, 5120, 5120),
        (75348, 13824, 5120),
        (32760, 1536, 1536),
        (512, 1536, 1536),
        (32760, 1536, 8960),
        (32760, 8960, 1536),
    ]
    for i, (m, k, n) in enumerate(test_sizes):
        print("-" * 30)
        print(f"Test {i + 1}: tensor size ({m}, {k}, {n})")
        test_accuracy(m, k, n)
        test_speed(m, k, n)
@@ -27,10 +27,12 @@ class TestQuantBF162MXFP6(unittest.TestCase):
        weight = torch.randn(n, k, dtype=self.dtype, device=self.device)
        weight_quant_pred, weight_scale_pred = scaled_fp6_quant(weight)
        bias = torch.rand(1, n, dtype=self.dtype, device=self.device) * 10
        alpha = torch.tensor(1.0, device=self.device, dtype=torch.float32)
        mm_pred = cutlass_scaled_mxfp6_mxfp8_mm(activation_quant_pred, weight_quant_pred, activation_scale_pred, weight_scale_pred, alpha=alpha)
        mm_pred = cutlass_scaled_mxfp6_mxfp8_mm(activation_quant_pred, weight_quant_pred, activation_scale_pred, weight_scale_pred, alpha=alpha, bias=bias)
        mm_real = linear(activation, weight, bias=None).to(torch.bfloat16)
        mm_real = linear(activation, weight, bias=bias).to(torch.bfloat16)
        self.assertTrue(error(mm_pred, mm_real) < 1e-2, f"Accuracy test failed for shape {m, k, n}: Error {error(mm_pred, mm_real)} exceeds threshold.")
......
import torch
import time
from test_bench import MMWeightMxfp8


def test_speed(m, k, n):
    with torch.no_grad():
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
        weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
        bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
        mm = MMWeightMxfp8(weight, bias)

        # warmup
        output_tensor = mm.apply(input_tensor)
        torch.cuda.synchronize()

        start_time = time.time()
        for i in range(100):
            output_tensor = mm.apply(input_tensor)
        torch.cuda.synchronize()
        end_time = time.time()
        lightx2v_kernel_time = (end_time - start_time) / 100
        print(f"lightx2v-kernel time: {lightx2v_kernel_time}")

        # reference path: use the same (m, k) x (k, n) problem shape as the kernel
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
        weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
        bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
        linear = torch.nn.Linear(k, n, bias=True).cuda()
        linear.weight.data = weight
        linear.bias.data = bias

        # warmup
        ref_output_tensor = linear(input_tensor)
        torch.cuda.synchronize()

        start_time = time.time()
        for i in range(100):
            ref_output_tensor = linear(input_tensor)
        torch.cuda.synchronize()
        end_time = time.time()
        ref_time = (end_time - start_time) / 100
        print(f"ref time: {ref_time}")
        print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}")


def test_accuracy(m, k, n):
    with torch.no_grad():
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
        weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
        bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
        linear = torch.nn.Linear(k, n, bias=True).cuda()
        linear.weight.data = weight
        linear.bias.data = bias
        ref_output_tensor = linear(input_tensor)

        mm = MMWeightMxfp8(weight, bias)
        output_tensor = mm.apply(input_tensor)
        # print(f"ref_output_tensor: {ref_output_tensor}")
        # print(f"output_tensor: {output_tensor}")

        # cosine similarity between the flattened outputs
        cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0)
        print(f"cos : {cos}")


if __name__ == "__main__":
    test_sizes = [
        (32130, 5120, 5120),
        (512, 5120, 5120),
        (257, 5120, 5120),
        (32130, 5120, 13824),
        (32130, 13824, 5120),
        (75348, 5120, 5120),
        (75348, 13824, 5120),
        (32760, 1536, 1536),
        (512, 1536, 1536),
        (32760, 1536, 8960),
        (32760, 8960, 1536),
    ]
    for i, (m, k, n) in enumerate(test_sizes):
        print("-" * 30)
        print(f"Test {i + 1}: tensor size ({m}, {k}, {n})")
        test_accuracy(m, k, n)
        test_speed(m, k, n)
@@ -27,10 +27,12 @@ class TestQuantBF162MXFP8(unittest.TestCase):
        weight = torch.randn(n, k, dtype=self.dtype, device=self.device)
        weight_quant_pred, weight_scale_pred = scaled_fp8_quant(weight)
        bias = torch.rand(1, n, dtype=self.dtype, device=self.device) * 10
        alpha = torch.tensor(1.0, device=self.device, dtype=torch.float32)
        mm_pred = cutlass_scaled_mxfp8_mm(activation_quant_pred, weight_quant_pred, activation_scale_pred, weight_scale_pred, alpha=alpha)
        mm_pred = cutlass_scaled_mxfp8_mm(activation_quant_pred, weight_quant_pred, activation_scale_pred, weight_scale_pred, alpha=alpha, bias=bias)
        mm_real = linear(activation, weight, bias=None).to(torch.bfloat16)
        mm_real = linear(activation, weight, bias=bias).to(torch.bfloat16)
        self.assertTrue(error(mm_pred, mm_real) < 1e-2, f"Accuracy test failed for shape {m, k, n}: Error {error(mm_pred, mm_real)} exceeds threshold.")
......
@@ -8,7 +8,7 @@ def test_speed(m, k, n):
    with torch.no_grad():
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
        weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
        bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
        bias = torch.ones(1, n, dtype=torch.bfloat16).cuda() * 50
        mm = MMWeightFp4(weight, bias)
@@ -53,7 +53,7 @@ def test_accuracy(m, k, n):
    with torch.no_grad():
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
        weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
        bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
        bias = torch.ones(1, n, dtype=torch.bfloat16).cuda() * 50
        linear = torch.nn.Linear(k, n, bias=True).cuda()
        linear.weight.data = weight
......