Commit eed650cb authored by helloyongyang

add mxfp6_mxfp8 mm kernel

parent cf1358e7
...@@ -24,3 +24,4 @@
*.mp4
build/
dist/
.cache/
...@@ -31,7 +31,7 @@ else()
FetchContent_Declare(
repo-cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass
- GIT_TAG b244379d9b15574e07b73b814b88bd2233f0b3ce
+ GIT_TAG b995f933179c22d3fe0d871c3a53d11e4681950f
GIT_SHALLOW OFF
)
FetchContent_MakeAvailable(repo-cutlass)
...
...@@ -38,20 +38,20 @@ pip install dist/*whl --force-reinstall --no-deps
##### cos and speed test, mm without bias
```
- python test/test_bench2.py
+ python test/nvfp4_nvfp4/test_bench2.py
```
##### cos and speed test, mm with bias
```
- python test/test_bench3_bias.py
+ python test/nvfp4_nvfp4/test_bench3_bias.py
```
##### Bandwidth utilization test for quant
```
- python test/test_quant_mem_utils.py
+ python test/nvfp4_nvfp4/test_quant_mem_utils.py
```
##### tflops test for mm
```
- python test/test_mm_tflops.py
+ python test/nvfp4_nvfp4/test_mm_tflops.py
```
...@@ -16,6 +16,11 @@ TORCH_LIBRARY_FRAGMENT(lightx2v_kernel, m) {
" Tensor! output_scale, Tensor! input_scale) -> ()");
m.impl("scaled_fp4_quant_sm120", torch::kCUDA, &scaled_fp4_quant_sm120);
m.def(
"cutlass_scaled_mxfp6_mxfp8_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor "
"alpha, Tensor? bias) -> ()");
m.impl("cutlass_scaled_mxfp6_mxfp8_mm_sm120", torch::kCUDA, &cutlass_scaled_mxfp6_mxfp8_mm_sm120);
}
REGISTER_EXTENSION(common_ops)
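For illustration, a minimal sketch of driving the newly registered op directly through the dispatcher. It mirrors the Python wrapper and smoke test added later in this commit; the import path and the random payloads are assumptions, and the scale tensors must already be in the padded, swizzled blocked layout that the kernel validates.
```
import torch
import lightx2v_kernel.gemm  # assumed to load the compiled extension and register the custom ops

m, k, n = 1024, 2048, 4096
mat_a = (torch.rand((m, k), device="cuda") * 10).to(torch.uint8)             # fp8 e4m3 payload bytes
mat_b = (torch.rand((n, k * 3 // 4), device="cuda") * 10).to(torch.uint8)    # fp6 e3m2 payload, 4 values packed into 3 bytes
scales_a = torch.rand((m, k // 32), device="cuda").to(torch.float8_e8m0fnu)  # one ue8m0 scale per 32-element block
scales_b = torch.rand((n, k // 32), device="cuda").to(torch.float8_e8m0fnu)
alpha = torch.tensor(1.0, device="cuda", dtype=torch.float32)
out = torch.empty((m, n), dtype=torch.bfloat16, device="cuda")

# Out-parameter call matching the schema registered above; bias is optional (None here).
torch.ops.lightx2v_kernel.cutlass_scaled_mxfp6_mxfp8_mm_sm120.default(
    out, mat_a, mat_b, scales_a, scales_b, alpha, None)
```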
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
// clang-format off
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/packed_stride.hpp"
// clang-format on
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \
}
#define CHECK_TYPE(x, st, m) TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)
using namespace cute;
struct Mxfp6Mxfp8GemmSm120 {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>; // Element type for A matrix operand
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
static constexpr int AlignmentA = 16; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::mx_float6_t<cutlass::float_e3m2_t>; // Element type for B matrix operand
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
static constexpr int AlignmentB = 128; // Memory access granularity/alignment of B matrix in units of elements
// C/D matrix configuration
using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Kernel functional config
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
// Kernel Perf config
using ThreadBlockShape = Shape<_128,_128,_128>; // Threadblock's tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ThreadBlockShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutCTag, AlignmentC,
ElementD, LayoutDTag, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutATag, AlignmentA,
ElementB, LayoutBTag, AlignmentB,
ElementAccumulator,
ThreadBlockShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto defaults to cooperative kernel schedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Strides and layouts derived from the kernel, used to build the GEMM arguments
using StrideA = typename Gemm::GemmKernel::StrideA;
using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{}));
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using StrideB = typename Gemm::GemmKernel::StrideB;
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using StrideC = typename Gemm::GemmKernel::StrideC;
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
using StrideD = typename Gemm::GemmKernel::StrideD;
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
};
// Populates a Gemm::Arguments structure from the given torch tensors
typename Mxfp6Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp6_mxfp8(
at::Tensor& D,
at::Tensor const& A,
at::Tensor const& B,
at::Tensor const& A_sf,
at::Tensor const& B_sf,
at::Tensor const& alpha,
c10::optional<torch::Tensor> const& bias,
int64_t M,
int64_t N,
int64_t K) {
using Sm1xxBlkScaledConfig = typename Mxfp6Mxfp8GemmSm120::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
int m = static_cast<int>(M);
int n = static_cast<int>(N);
int k = static_cast<int>(K);
auto stride_A = cutlass::make_cute_packed_stride(Mxfp6Mxfp8GemmSm120::StrideA{}, {m, k, 1});
auto stride_B = cutlass::make_cute_packed_stride(Mxfp6Mxfp8GemmSm120::StrideB{}, {n, k, 1});
auto stride_D = cutlass::make_cute_packed_stride(Mxfp6Mxfp8GemmSm120::StrideD{}, {m, n, 1});
auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
if (bias){
auto stride_bias = cutlass::make_cute_packed_stride(Mxfp6Mxfp8GemmSm120::StrideC{}, {});
typename Mxfp6Mxfp8GemmSm120::Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{m, n, k, 1},
{// Mainloop arguments
static_cast<Mxfp6Mxfp8GemmSm120::Gemm::ElementA const*>(A.data_ptr()),
stride_A,
static_cast<Mxfp6Mxfp8GemmSm120::Gemm::ElementB const*>(B.data_ptr()),
stride_B,
static_cast<cutlass::float_ue8m0_t const*>(A_sf.data_ptr()),
layout_SFA,
static_cast<cutlass::float_ue8m0_t const*>(B_sf.data_ptr()),
layout_SFB},
{ // Epilogue arguments
{}, // epilogue.thread
static_cast<Mxfp6Mxfp8GemmSm120::Gemm::ElementC const*>(bias->data_ptr()),
stride_bias,
static_cast<Mxfp6Mxfp8GemmSm120::Gemm::ElementD*>(D.data_ptr()),
stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
return arguments;
} else {
typename Mxfp6Mxfp8GemmSm120::Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{m, n, k, 1},
{// Mainloop arguments
static_cast<Mxfp6Mxfp8GemmSm120::Gemm::ElementA const*>(A.data_ptr()),
stride_A,
static_cast<Mxfp6Mxfp8GemmSm120::Gemm::ElementB const*>(B.data_ptr()),
stride_B,
static_cast<cutlass::float_ue8m0_t const*>(A_sf.data_ptr()),
layout_SFA,
static_cast<cutlass::float_ue8m0_t const*>(B_sf.data_ptr()),
layout_SFB},
{ // Epilogue arguments
{}, // epilogue.thread
static_cast<Mxfp6Mxfp8GemmSm120::Gemm::ElementC const*>(D.data_ptr()),
stride_D,
static_cast<Mxfp6Mxfp8GemmSm120::Gemm::ElementD*>(D.data_ptr()),
stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
return arguments;
}
}
void runGemmMxfp6Mxfp8Sm120(
at::Tensor& D,
at::Tensor const& A,
at::Tensor const& B,
at::Tensor const& A_sf,
at::Tensor const& B_sf,
at::Tensor const& alpha,
c10::optional<torch::Tensor> const& bias,
int64_t m,
int64_t n,
int64_t k,
cudaStream_t stream) {
typename Mxfp6Mxfp8GemmSm120::Gemm gemm;
auto arguments = args_from_options_mxfp6_mxfp8(D, A, B, A_sf, B_sf, alpha, bias, m, n, k);
size_t workspace_size = Mxfp6Mxfp8GemmSm120::Gemm::get_workspace_size(arguments);
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
auto workspace = torch::empty(workspace_size, workspace_options);
CUTLASS_CHECK(gemm.can_implement(arguments));
CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream));
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
}
constexpr auto FP6_FP8_TYPE = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e8m0fnu;
void cutlass_scaled_mxfp6_mxfp8_mm_sm120(
torch::Tensor& D,
torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha,
c10::optional<torch::Tensor> const& bias) {
CHECK_INPUT(A, FP6_FP8_TYPE, "a");
CHECK_INPUT(B, FP6_FP8_TYPE, "b");
CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");
TORCH_CHECK(A.dim() == 2, "a must be a matrix");
TORCH_CHECK(B.dim() == 2, "b must be a matrix");
// TORCH_CHECK(
// A.sizes()[1] == B.sizes()[1],
// "a and b shapes cannot be multiplied (",
// A.sizes()[0],
// "x",
// A.sizes()[1],
// " and ",
// B.sizes()[0],
// "x",
// B.sizes()[1],
// ")");
auto const m = A.sizes()[0];
auto const n = B.sizes()[0];
auto const k = A.sizes()[1];
constexpr int alignment_a = 16;
constexpr int alignment_b = 128;
TORCH_CHECK(
k % alignment_a == 0,
"Expected k to be divisible by ",
alignment_a,
", but got a shape: (",
A.sizes()[0],
"x",
A.sizes()[1],
"), k: ",
k,
".");
TORCH_CHECK(
n % alignment_b == 0,
"Expected n to be divisible by ",
alignment_b,
", but got b shape: (",
B.sizes()[0],
"x",
B.sizes()[1],
").");
auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
int rounded_m = round_up(m, 128);
int rounded_n = round_up(n, 128);
// Each MX scale factor covers a 32-element block along k, so the scale tensors
// have k / 32 columns, padded up to a multiple of 4 for the swizzled layout.
int rounded_k = round_up(k / 32, 4);
TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
TORCH_CHECK(
A_sf.sizes()[1] == B_sf.sizes()[1],
"scale_a and scale_b shapes cannot be multiplied (",
A_sf.sizes()[0],
"x",
A_sf.sizes()[1],
" and ",
B_sf.sizes()[0],
"x",
B_sf.sizes()[1],
")");
TORCH_CHECK(
A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
"scale_a must be padded and swizzled to a shape (",
rounded_m,
"x",
rounded_k,
"), but got a shape (",
A_sf.sizes()[0],
"x",
A_sf.sizes()[1],
")");
TORCH_CHECK(
B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
"scale_b must be padded and swizzled to a shape (",
rounded_n,
"x",
rounded_k,
"), but got a shape (",
B_sf.sizes()[0],
"x",
B_sf.sizes()[1],
")");
auto out_dtype = D.dtype();
at::cuda::CUDAGuard device_guard{(char)A.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
runGemmMxfp6Mxfp8Sm120(D, A, B, A_sf, B_sf, alpha, bias, m, n, k, stream);
}
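As a worked example of the padding rule checked above (a sketch for illustration only, not part of the kernel), the expected scale-factor shapes follow from m, n, k like this:
```
def round_up(x, y):
    return (x + y - 1) // y * y

m, n, k = 1000, 4096, 2048
rounded_m = round_up(m, 128)      # 1024: scale-factor rows are padded to a multiple of 128
rounded_n = round_up(n, 128)      # 4096
rounded_k = round_up(k // 32, 4)  # 2048 / 32 = 64 scale columns, padded to a multiple of 4

# The kernel expects A_sf of shape (rounded_m, rounded_k) and B_sf of shape (rounded_n, rounded_k),
# already swizzled into the blocked layout described by Sm1xxBlkScaledConfig.
print(rounded_m, rounded_n, rounded_k)  # 1024 4096 64
```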
...@@ -103,7 +103,7 @@ struct Fp4GemmSm120 {
// Populates a Gemm::Arguments structure from the given commandline options
- typename Fp4GemmSm120::Gemm::Arguments args_from_options(
+ typename Fp4GemmSm120::Gemm::Arguments args_from_options_nvfp4_nvfp4(
at::Tensor& D,
at::Tensor const& A,
at::Tensor const& B,
...@@ -176,7 +176,7 @@ typename Fp4GemmSm120::Gemm::Arguments args_from_options(
}
- void runGemm(
+ void runGemmNvfp4Sm120(
at::Tensor& D,
at::Tensor const& A,
at::Tensor const& B,
...@@ -190,7 +190,7 @@ void runGemm(
cudaStream_t stream) {
typename Fp4GemmSm120::Gemm gemm;
- auto arguments = args_from_options(D, A, B, A_sf, B_sf, alpha, bias, m, n, k);
+ auto arguments = args_from_options_nvfp4_nvfp4(D, A, B, A_sf, B_sf, alpha, bias, m, n, k);
size_t workspace_size = Fp4GemmSm120::Gemm::get_workspace_size(arguments);
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
auto workspace = torch::empty(workspace_size, workspace_options);
...@@ -308,5 +308,5 @@ void cutlass_scaled_fp4_mm_sm120(
at::cuda::CUDAGuard device_guard{(char)A.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
- runGemm(D, A, B, A_sf, B_sf, alpha, bias, m, n, k, stream);
+ runGemmNvfp4Sm120(D, A, B, A_sf, B_sf, alpha, bias, m, n, k, stream);
}
...@@ -55,3 +55,13 @@ void cutlass_scaled_fp4_mm_sm120(
void scaled_fp4_quant_sm120(
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf);
void cutlass_scaled_mxfp6_mxfp8_mm_sm120(
torch::Tensor& D,
torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha,
c10::optional<torch::Tensor> const& bias);
...@@ -2,14 +2,6 @@ import torch
def cutlass_scaled_fp4_mm(mat_a, mat_b, scales_a, scales_b, alpha, bias=None):
-    """
-    mat_a: (m, k) cutlass::float_e2m1_t
-    mat_b: (n, k) cutlass::float_e2m1_t
-    scales_a: (m, 1) cutlass::float_ue4m3_t
-    scales_b: (n, 1) cutlass::float_ue4m3_t
-    alpha: (1, 1) float
-    bias: (m, n) cutlass::bfloat16_t
-    """
    m, n = mat_a.shape[0], mat_b.shape[0]
    out = torch.empty((m, n), dtype=torch.bfloat16, device=mat_a.device)
    torch.ops.lightx2v_kernel.cutlass_scaled_fp4_mm_sm120.default(out, mat_a, mat_b, scales_a, scales_b, alpha, bias)
...@@ -61,3 +53,10 @@ def scaled_fp4_quant(input: torch.Tensor, input_global_scale: torch.Tensor):
    torch.ops.lightx2v_kernel.scaled_fp4_quant_sm120.default(output, input, output_scale, input_global_scale)
    output_scale = output_scale.view(torch.float8_e4m3fn)
    return output, output_scale
def cutlass_scaled_mxfp6_mxfp8_mm(mat_a, mat_b, scales_a, scales_b, alpha, bias=None):
    m, n = mat_a.shape[0], mat_b.shape[0]
    out = torch.empty((m, n), dtype=torch.bfloat16, device=mat_a.device)
    torch.ops.lightx2v_kernel.cutlass_scaled_mxfp6_mxfp8_mm_sm120.default(out, mat_a, mat_b, scales_a, scales_b, alpha, bias)
    return out
import torch
from lightx2v_kernel.gemm import cutlass_scaled_mxfp6_mxfp8_mm


def test_cutlass_scaled_mxfp6_mxfp8_mm_sm120():
    m, k, n = 1024, 2048, 4096
    input_shape = (m, k)
    weight_shape = (n, k)
    input_tensor_quant = (torch.rand((input_shape[0], input_shape[1]), device="cuda") * 10).to(torch.uint8)
    weight = (torch.rand((weight_shape[0], weight_shape[1] * 3 // 4), device="cuda") * 10).to(torch.uint8)
    print(f"shape: {input_tensor_quant.shape}, {weight.shape}")
    input_tensor_scale = torch.rand((input_shape[0], input_shape[1] // 32), device="cuda").to(torch.float8_e8m0fnu)
    weight_scale = torch.rand(weight_shape[0], weight_shape[1] // 32, device="cuda").to(torch.float8_e8m0fnu)
    print(f"shape: {input_tensor_scale.shape}, {weight_scale.shape}")
    alpha = torch.tensor(0.0002765655517578125, device="cuda", dtype=torch.float32)
    bias = None
    out = cutlass_scaled_mxfp6_mxfp8_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias)
    print(f"out: {out}, shape: {out.shape}")


if __name__ == "__main__":
    test_cutlass_scaled_mxfp6_mxfp8_mm_sm120()
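The test above is a smoke test that only prints the result; purely as a suggestion (not part of the commit), a couple of cheap sanity checks could be appended inside test_cutlass_scaled_mxfp6_mxfp8_mm_sm120 after the mm call:
```
assert out.shape == (m, n) and out.dtype == torch.bfloat16
assert torch.isfinite(out.float()).all(), "unexpected inf/nan in mxfp6_mxfp8 mm output"
```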