Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3fb4b5fa
Commit
3fb4b5fa
authored
Mar 23, 2026
by
zhuwenwen
Browse files
Merge tag 'v0.18.0' into v0.18.0-ori
parents
bcf25339
89138b21
Changes
488
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1435 additions
and
262 deletions
+1435
-262
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu
+60
-0
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_functor.cuh
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_functor.cuh
+141
-0
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_launcher.cuh
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_launcher.cuh
+179
-0
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_traits.cuh
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_traits.cuh
+127
-0
csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu
csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu
+60
-0
csrc/moe/mxfp8_moe/mxfp8_experts_quant.cuh
csrc/moe/mxfp8_moe/mxfp8_experts_quant.cuh
+414
-0
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h
.../permute_unpermute_kernels/moe_permute_unpermute_kernel.h
+1
-1
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl
...ermute_unpermute_kernels/moe_permute_unpermute_kernel.inl
+8
-9
csrc/moe/router_gemm.cu
csrc/moe/router_gemm.cu
+52
-0
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+8
-0
csrc/ops.h
csrc/ops.h
+30
-13
csrc/quantization/activation_kernels.cu
csrc/quantization/activation_kernels.cu
+1
-1
csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu
...quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu
+21
-19
csrc/quantization/fp4/nvfp4_experts_quant.cu
csrc/quantization/fp4/nvfp4_experts_quant.cu
+2
-2
csrc/quantization/fp4/nvfp4_quant_entry.cu
csrc/quantization/fp4/nvfp4_quant_entry.cu
+34
-3
csrc/quantization/fp4/nvfp4_quant_kernels.cu
csrc/quantization/fp4/nvfp4_quant_kernels.cu
+21
-19
csrc/quantization/fp4/nvfp4_utils.cuh
csrc/quantization/fp4/nvfp4_utils.cuh
+29
-133
csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
.../fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
+53
-24
csrc/quantization/fused_kernels/layernorm_utils.cuh
csrc/quantization/fused_kernels/layernorm_utils.cuh
+61
-37
csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh
...ization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh
+133
-1
No files found.
Too many changes to show.
To preserve performance only
488 of 488+
files are displayed.
Plain diff
Email patch
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu
0 → 100644
View file @
3fb4b5fa
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// Adapted from SGLang:
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu
#include <torch/all.h>
#include "cutlass_mxfp8_grouped_mm_launcher.cuh"
void
cutlass_mxfp8_grouped_mm
(
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
sfa
,
const
torch
::
Tensor
&
sfb
,
torch
::
Tensor
&
d
,
const
torch
::
Tensor
&
problem_sizes
,
const
torch
::
Tensor
&
expert_offsets
,
const
torch
::
Tensor
&
blockscale_offsets
)
{
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
TORCH_CHECK
(
problem_sizes
.
dim
()
==
2
,
"problem_sizes must be 2D tensor"
);
TORCH_CHECK
(
problem_sizes
.
size
(
1
)
==
3
,
"problem_sizes must have shape (num_experts, 3)"
);
TORCH_CHECK
(
problem_sizes
.
size
(
0
)
==
expert_offsets
.
size
(
0
),
"Number of experts in problem_sizes must match expert_offsets"
);
TORCH_CHECK
(
problem_sizes
.
dtype
()
==
torch
::
kInt32
,
"problem_sizes must be int32"
);
TORCH_CHECK
(
expert_offsets
.
dtype
()
==
torch
::
kInt32
,
"expert_offsets must be int32"
);
TORCH_CHECK
(
blockscale_offsets
.
dtype
()
==
torch
::
kInt32
,
"blockscale_offsets must be int32"
);
TORCH_CHECK
(
a
.
dim
()
==
2
,
"a must be a 2D tensor of shape (num_tokens, k)"
);
TORCH_CHECK
(
b
.
dim
()
==
3
,
"b must be a 3D tensor of shape (num_experts, k, n)"
);
TORCH_CHECK
(
a
.
size
(
1
)
==
b
.
size
(
1
)
&&
a
.
size
(
1
)
%
128
==
0
,
"k should align 128"
);
TORCH_CHECK
(
b
.
size
(
2
)
%
128
==
0
,
"n should align 128"
);
TORCH_CHECK
(
a
.
strides
()[
1
]
==
1
,
"a must be row major"
);
TORCH_CHECK
(
b
.
strides
()[
1
]
==
1
,
"b must be column major"
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
(
d
.
dtype
()
==
torch
::
kBFloat16
)
{
expert_specialization
::
cutlass_mxfp8_grouped_mm_dispatch_out_dtype
<
cutlass
::
bfloat16_t
>
(
a
,
b
,
sfa
,
sfb
,
d
,
problem_sizes
,
expert_offsets
,
blockscale_offsets
,
stream
);
}
else
if
(
d
.
dtype
()
==
torch
::
kFloat16
)
{
expert_specialization
::
cutlass_mxfp8_grouped_mm_dispatch_out_dtype
<
cutlass
::
half_t
>
(
a
,
b
,
sfa
,
sfb
,
d
,
problem_sizes
,
expert_offsets
,
blockscale_offsets
,
stream
);
}
else
{
TORCH_CHECK
(
false
,
"dtype must be kFloat16 or kBFloat16"
);
}
#else
TORCH_CHECK
(
false
,
"No implemented cutlass_mxfp8_grouped_mm for "
"current device"
);
#endif
}
#include "core/registration.h"
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"cutlass_mxfp8_grouped_mm"
,
cutlass_mxfp8_grouped_mm
);
}
\ No newline at end of file
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_functor.cuh
0 → 100644
View file @
3fb4b5fa
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// Adapted from SGLang:
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_functor.cuh
#pragma once
#include <cuda.h>
#include "cute/tensor.hpp"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass_mxfp8_grouped_mm_traits.cuh"
namespace
expert_specialization
{
using
namespace
cute
;
template
<
typename
GemmTraits
>
struct
CutlassMxfp8GroupedMmOffsetFunctor
{
using
Gemm
=
typename
GemmTraits
::
Gemm
;
using
ElementA
=
typename
Gemm
::
ElementA
;
using
ElementB
=
typename
Gemm
::
ElementB
;
using
ElementSF
=
typename
GemmTraits
::
ElementSF
;
using
ElementD
=
typename
GemmTraits
::
ElementOutput
;
// Input
int
*
expert_offsets
{
nullptr
};
int
*
blockscale_offsets
{
nullptr
};
// Output
ElementA
*
a_base
{
nullptr
};
ElementB
*
b_base
{
nullptr
};
ElementSF
*
sfa_base
{
nullptr
};
ElementSF
*
sfb_base
{
nullptr
};
ElementD
*
d_base
{
nullptr
};
ElementA
**
a_offsets
{
nullptr
};
ElementB
**
b_offsets
{
nullptr
};
ElementSF
**
sfa_offsets
{
nullptr
};
ElementSF
**
sfb_offsets
{
nullptr
};
ElementD
**
d_offsets
{
nullptr
};
CutlassMxfp8GroupedMmOffsetFunctor
()
=
default
;
CutlassMxfp8GroupedMmOffsetFunctor
(
int
*
_expert_offsets
,
int
*
_blockscale_offsets
,
ElementA
*
_a_base
,
ElementB
*
_b_base
,
ElementSF
*
_sfa_base
,
ElementSF
*
_sfb_base
,
ElementD
*
_d_base
,
ElementA
**
_a_offsets
,
ElementB
**
_b_offsets
,
ElementSF
**
_sfa_offsets
,
ElementSF
**
_sfb_offsets
,
ElementD
**
_d_offsets
)
:
expert_offsets
{
_expert_offsets
},
blockscale_offsets
{
_blockscale_offsets
},
a_base
(
_a_base
),
b_base
(
_b_base
),
sfa_base
(
_sfa_base
),
sfb_base
(
_sfb_base
),
d_base
(
_d_base
),
a_offsets
(
_a_offsets
),
b_offsets
(
_b_offsets
),
sfa_offsets
(
_sfa_offsets
),
sfb_offsets
(
_sfb_offsets
),
d_offsets
(
_d_offsets
)
{}
void
CUTE_DEVICE
operator
()(
int64_t
expert_id
,
int
m
,
int
n
,
int
k
)
{
int64_t
expert_offset
=
static_cast
<
int64_t
>
(
expert_offsets
[
expert_id
]);
int64_t
blockscale_offset
=
static_cast
<
int64_t
>
(
blockscale_offsets
[
expert_id
]);
int64_t
a_stride
=
expert_offset
*
k
;
int64_t
b_stride
=
expert_id
*
k
*
n
;
int64_t
d_stride
=
expert_offset
*
n
;
int64_t
sfa_stride
=
blockscale_offset
*
(
k
/
32
);
int64_t
sfb_stride
=
expert_id
*
n
*
(
k
/
32
);
a_offsets
[
expert_id
]
=
a_base
+
a_stride
;
b_offsets
[
expert_id
]
=
b_base
+
b_stride
;
sfa_offsets
[
expert_id
]
=
sfa_base
+
sfa_stride
;
sfb_offsets
[
expert_id
]
=
sfb_base
+
sfb_stride
;
d_offsets
[
expert_id
]
=
d_base
+
d_stride
;
}
};
template
<
typename
GemmTraits
>
struct
CutlassMxfp8GroupedMmLayoutFunctor
{
using
Sm1xxBlkScaledConfig
=
typename
GemmTraits
::
Sm1xxBlkScaledConfig
;
using
LayoutSFA
=
typename
GemmTraits
::
LayoutSFA
;
using
LayoutSFB
=
typename
GemmTraits
::
LayoutSFB
;
LayoutSFA
*
layout_sfa_base
{
nullptr
};
LayoutSFB
*
layout_sfb_base
{
nullptr
};
CutlassMxfp8GroupedMmLayoutFunctor
()
=
default
;
CutlassMxfp8GroupedMmLayoutFunctor
(
LayoutSFA
*
_layout_sfa_base
,
LayoutSFB
*
_layout_sfb_base
)
:
layout_sfa_base
(
_layout_sfa_base
),
layout_sfb_base
(
_layout_sfb_base
)
{}
void
CUTE_DEVICE
operator
()(
int64_t
expert_id
,
int
m
,
int
n
,
int
k
)
{
LayoutSFA
*
layout_sfa_ptr
=
layout_sfa_base
+
expert_id
;
LayoutSFB
*
layout_sfb_ptr
=
layout_sfb_base
+
expert_id
;
*
layout_sfa_ptr
=
Sm1xxBlkScaledConfig
::
tile_atom_to_shape_SFA
(
cute
::
make_shape
(
m
,
n
,
k
,
1
));
*
layout_sfb_ptr
=
Sm1xxBlkScaledConfig
::
tile_atom_to_shape_SFB
(
cute
::
make_shape
(
m
,
n
,
k
,
1
));
}
};
template
<
typename
GemmTraits
>
struct
CutlassMxfp8GroupedMmStrideFunctor
{
using
StrideA
=
typename
GemmTraits
::
StrideA
;
using
StrideB
=
typename
GemmTraits
::
StrideB
;
using
StrideD
=
typename
GemmTraits
::
StrideD
;
StrideA
*
stride_A_base
{
nullptr
};
StrideB
*
stride_B_base
{
nullptr
};
StrideD
*
stride_D_base
{
nullptr
};
CutlassMxfp8GroupedMmStrideFunctor
()
=
default
;
CutlassMxfp8GroupedMmStrideFunctor
(
StrideA
*
_stride_A_base
,
StrideB
*
_stride_B_base
,
StrideD
*
_stride_D_base
)
:
stride_A_base
(
_stride_A_base
),
stride_B_base
(
_stride_B_base
),
stride_D_base
(
_stride_D_base
)
{}
void
CUTE_DEVICE
operator
()(
int64_t
expert_id
,
int
m
,
int
n
,
int
k
)
{
StrideA
*
stride_A
=
stride_A_base
+
expert_id
;
StrideB
*
stride_B
=
stride_B_base
+
expert_id
;
StrideD
*
stride_D
=
stride_D_base
+
expert_id
;
*
stride_A
=
cutlass
::
make_cute_packed_stride
(
StrideA
{},
{
m
,
k
,
1
});
*
stride_B
=
cutlass
::
make_cute_packed_stride
(
StrideB
{},
{
n
,
k
,
1
});
*
stride_D
=
cutlass
::
make_cute_packed_stride
(
StrideD
{},
{
m
,
n
,
1
});
}
};
template
<
typename
OffsetFunctor
,
typename
LayoutFunctor
,
typename
StrideFunctor
>
__global__
void
cutlassMxfp8GroupedMmPreComputeKernel
(
int
*
problem_sizes
,
OffsetFunctor
offset_functor
,
LayoutFunctor
layout_functor
,
StrideFunctor
stride_functor
)
{
int64_t
expert_id
=
static_cast
<
int64_t
>
(
threadIdx
.
x
);
int
m
=
problem_sizes
[
expert_id
*
3
+
0
];
int
n
=
problem_sizes
[
expert_id
*
3
+
1
];
int
k
=
problem_sizes
[
expert_id
*
3
+
2
];
offset_functor
(
expert_id
,
m
,
n
,
k
);
layout_functor
(
expert_id
,
m
,
n
,
k
);
stride_functor
(
expert_id
,
m
,
n
,
k
);
}
}
// namespace expert_specialization
\ No newline at end of file
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_launcher.cuh
0 → 100644
View file @
3fb4b5fa
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// Adapted from SGLang:
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_launcher.cuh
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <cassert>
#include <iostream>
#include <string>
#include "cute/tensor.hpp"
#include "cutlass_mxfp8_grouped_mm_functor.cuh"
#include "cutlass_mxfp8_grouped_mm_traits.cuh"
namespace
expert_specialization
{
template
<
typename
GemmTraits
>
void
cutlass_mxfp8_grouped_mm_pre_compute
(
torch
::
Tensor
&
a_ptrs
,
torch
::
Tensor
&
b_ptrs
,
torch
::
Tensor
&
sfa_ptrs
,
torch
::
Tensor
&
sfb_ptrs
,
torch
::
Tensor
&
d_ptrs
,
torch
::
Tensor
&
stride_a
,
torch
::
Tensor
&
stride_b
,
torch
::
Tensor
&
stride_d
,
torch
::
Tensor
&
layout_sfa
,
torch
::
Tensor
&
layout_sfb
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
sfa
,
const
torch
::
Tensor
&
sfb
,
const
torch
::
Tensor
&
d
,
const
torch
::
Tensor
&
problem_sizes
,
const
torch
::
Tensor
&
expert_offsets
,
const
torch
::
Tensor
&
blockscale_offsets
,
cudaStream_t
stream
)
{
using
OffsetFunctor
=
CutlassMxfp8GroupedMmOffsetFunctor
<
GemmTraits
>
;
using
ElementA
=
typename
OffsetFunctor
::
ElementA
;
using
ElementB
=
typename
OffsetFunctor
::
ElementB
;
using
ElementSF
=
typename
OffsetFunctor
::
ElementSF
;
using
ElementD
=
typename
OffsetFunctor
::
ElementD
;
using
LayoutFunctor
=
CutlassMxfp8GroupedMmLayoutFunctor
<
GemmTraits
>
;
using
LayoutSFA
=
typename
LayoutFunctor
::
LayoutSFA
;
using
LayoutSFB
=
typename
LayoutFunctor
::
LayoutSFB
;
using
StrideFunctor
=
CutlassMxfp8GroupedMmStrideFunctor
<
GemmTraits
>
;
using
StrideA
=
typename
StrideFunctor
::
StrideA
;
using
StrideB
=
typename
StrideFunctor
::
StrideB
;
using
StrideD
=
typename
StrideFunctor
::
StrideD
;
int
num_experts
=
(
int
)
expert_offsets
.
size
(
0
);
TORCH_CHECK
(
num_experts
<=
1024
,
"Number of experts cannot exceed 1024, the maximum number of "
"threads per block."
);
OffsetFunctor
offset_functor
(
reinterpret_cast
<
int
*>
(
expert_offsets
.
data_ptr
()),
reinterpret_cast
<
int
*>
(
blockscale_offsets
.
data_ptr
()),
reinterpret_cast
<
ElementA
*>
(
a
.
data_ptr
()),
reinterpret_cast
<
ElementB
*>
(
b
.
data_ptr
()),
reinterpret_cast
<
ElementSF
*>
(
sfa
.
data_ptr
()),
reinterpret_cast
<
ElementSF
*>
(
sfb
.
data_ptr
()),
reinterpret_cast
<
ElementD
*>
(
d
.
data_ptr
()),
reinterpret_cast
<
ElementA
**>
(
a_ptrs
.
data_ptr
()),
reinterpret_cast
<
ElementB
**>
(
b_ptrs
.
data_ptr
()),
reinterpret_cast
<
ElementSF
**>
(
sfa_ptrs
.
data_ptr
()),
reinterpret_cast
<
ElementSF
**>
(
sfb_ptrs
.
data_ptr
()),
reinterpret_cast
<
ElementD
**>
(
d_ptrs
.
data_ptr
()));
LayoutFunctor
layout_functor
(
reinterpret_cast
<
LayoutSFA
*>
(
layout_sfa
.
data_ptr
()),
reinterpret_cast
<
LayoutSFB
*>
(
layout_sfb
.
data_ptr
()));
StrideFunctor
stride_functor
(
reinterpret_cast
<
StrideA
*>
(
stride_a
.
data_ptr
()),
reinterpret_cast
<
StrideB
*>
(
stride_b
.
data_ptr
()),
reinterpret_cast
<
StrideD
*>
(
stride_d
.
data_ptr
()));
cutlassMxfp8GroupedMmPreComputeKernel
<<<
1
,
num_experts
,
0
,
stream
>>>
(
static_cast
<
int
*>
(
problem_sizes
.
data_ptr
()),
offset_functor
,
layout_functor
,
stride_functor
);
}
template
<
typename
GemmTraits
>
void
cutlass_mxfp8_grouped_mm
(
const
torch
::
Tensor
&
a_ptrs
,
const
torch
::
Tensor
&
b_ptrs
,
const
torch
::
Tensor
&
sfa_ptrs
,
const
torch
::
Tensor
&
sfb_ptrs
,
const
torch
::
Tensor
&
d_ptrs
,
const
torch
::
Tensor
&
stride_a
,
const
torch
::
Tensor
&
stride_b
,
const
torch
::
Tensor
&
stride_d
,
const
torch
::
Tensor
&
layout_sfa
,
const
torch
::
Tensor
&
layout_sfb
,
const
torch
::
Tensor
&
problem_sizes
,
cudaStream_t
stream
)
{
using
Gemm
=
typename
GemmTraits
::
Gemm
;
using
ElementA
=
typename
Gemm
::
ElementA
;
using
ElementB
=
typename
Gemm
::
ElementB
;
using
ElementSF
=
typename
GemmTraits
::
ElementSF
;
using
ElementD
=
typename
GemmTraits
::
ElementOutput
;
using
StrideA
=
typename
GemmTraits
::
StrideA
;
using
StrideB
=
typename
GemmTraits
::
StrideB
;
using
StrideD
=
typename
GemmTraits
::
StrideD
;
using
LayoutSFA
=
typename
GemmTraits
::
LayoutSFA
;
using
LayoutSFB
=
typename
GemmTraits
::
LayoutSFB
;
using
UnderlyingProblemShape
=
typename
GemmTraits
::
ProblemShape
::
UnderlyingProblemShape
;
cutlass
::
KernelHardwareInfo
hw_info
;
hw_info
.
device_id
=
c10
::
cuda
::
current_device
();
hw_info
.
sm_count
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
multiProcessorCount
;
hw_info
.
cluster_shape
=
GemmTraits
::
MMAConfig
::
preferred_cluster
;
hw_info
.
cluster_shape_fallback
=
GemmTraits
::
MMAConfig
::
fallback_cluster
;
int
num_experts
=
(
int
)
problem_sizes
.
size
(
0
);
UnderlyingProblemShape
*
underlying_problem_shape
=
reinterpret_cast
<
UnderlyingProblemShape
*>
(
problem_sizes
.
data_ptr
());
typename
Gemm
::
Arguments
arguments
=
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGrouped
,
{
num_experts
,
underlying_problem_shape
,
nullptr
},
{
reinterpret_cast
<
const
ElementA
**>
(
a_ptrs
.
data_ptr
()),
reinterpret_cast
<
StrideA
*>
(
stride_a
.
data_ptr
()),
reinterpret_cast
<
const
ElementB
**>
(
b_ptrs
.
data_ptr
()),
reinterpret_cast
<
StrideB
*>
(
stride_b
.
data_ptr
()),
reinterpret_cast
<
const
ElementSF
**>
(
sfa_ptrs
.
data_ptr
()),
reinterpret_cast
<
LayoutSFA
*>
(
layout_sfa
.
data_ptr
()),
reinterpret_cast
<
const
ElementSF
**>
(
sfb_ptrs
.
data_ptr
()),
reinterpret_cast
<
LayoutSFB
*>
(
layout_sfb
.
data_ptr
())},
{{},
nullptr
,
nullptr
,
reinterpret_cast
<
ElementD
**>
(
d_ptrs
.
data_ptr
()),
reinterpret_cast
<
StrideD
*>
(
stride_d
.
data_ptr
())},
hw_info
,
{}
// Scheduler
};
Gemm
gemm
;
auto
can_implement_status
=
gemm
.
can_implement
(
arguments
);
TORCH_CHECK
(
can_implement_status
==
cutlass
::
Status
::
kSuccess
,
"Failed to implement GEMM"
);
torch
::
TensorOptions
options_uint8
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
d_ptrs
.
device
());
size_t
workspace_size
=
gemm
.
get_workspace_size
(
arguments
);
torch
::
Tensor
workspace
=
torch
::
empty
(
workspace_size
,
options_uint8
);
auto
status
=
gemm
.
initialize
(
arguments
,
workspace
.
data_ptr
(),
stream
);
TORCH_CHECK
(
status
==
cutlass
::
Status
::
kSuccess
,
"Failed to initialize GEMM"
);
status
=
gemm
.
run
(
stream
,
nullptr
,
true
);
// Enable PDL
TORCH_CHECK
(
status
==
cutlass
::
Status
::
kSuccess
,
"Failed to run GEMM"
);
}
template
<
typename
OutType
>
void
cutlass_mxfp8_grouped_mm_dispatch_out_dtype
(
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
sfa
,
const
torch
::
Tensor
&
sfb
,
torch
::
Tensor
&
d
,
const
torch
::
Tensor
&
problem_sizes
,
const
torch
::
Tensor
&
expert_offsets
,
const
torch
::
Tensor
&
blockscale_offsets
,
cudaStream_t
stream
)
{
int
num_experts
=
(
int
)
problem_sizes
.
size
(
0
);
torch
::
TensorOptions
options_int64
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt64
).
device
(
a
.
device
());
torch
::
TensorOptions
options_int32
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt32
).
device
(
a
.
device
());
torch
::
Tensor
a_ptrs
=
torch
::
empty
(
num_experts
,
options_int64
);
torch
::
Tensor
b_ptrs
=
torch
::
empty
(
num_experts
,
options_int64
);
torch
::
Tensor
sfa_ptrs
=
torch
::
empty
(
num_experts
,
options_int64
);
torch
::
Tensor
sfb_ptrs
=
torch
::
empty
(
num_experts
,
options_int64
);
torch
::
Tensor
d_ptrs
=
torch
::
empty
(
num_experts
,
options_int64
);
torch
::
Tensor
stride_a
=
torch
::
empty
(
num_experts
,
options_int64
);
torch
::
Tensor
stride_b
=
torch
::
empty
(
num_experts
,
options_int64
);
torch
::
Tensor
stride_d
=
torch
::
empty
(
num_experts
,
options_int64
);
torch
::
Tensor
layout_sfa
=
torch
::
empty
({
num_experts
,
5
},
options_int32
);
torch
::
Tensor
layout_sfb
=
torch
::
empty
({
num_experts
,
5
},
options_int32
);
using
GemmTraits
=
CutlassMxfp8GroupedMmGemmTraits
<
MMA1SMConfig
,
OutType
>
;
cutlass_mxfp8_grouped_mm_pre_compute
<
GemmTraits
>
(
a_ptrs
,
b_ptrs
,
sfa_ptrs
,
sfb_ptrs
,
d_ptrs
,
stride_a
,
stride_b
,
stride_d
,
layout_sfa
,
layout_sfb
,
a
,
b
,
sfa
,
sfb
,
d
,
problem_sizes
,
expert_offsets
,
blockscale_offsets
,
stream
);
cutlass_mxfp8_grouped_mm
<
GemmTraits
>
(
a_ptrs
,
b_ptrs
,
sfa_ptrs
,
sfb_ptrs
,
d_ptrs
,
stride_a
,
stride_b
,
stride_d
,
layout_sfa
,
layout_sfb
,
problem_sizes
,
stream
);
}
}
// namespace expert_specialization
\ No newline at end of file
csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_traits.cuh
0 → 100644
View file @
3fb4b5fa
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// Adapted from SGLang:
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_traits.cuh
#pragma once
// Misc
#include "cute/tensor.hpp"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/cutlass.h"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_size.h"
// Collective Builder
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
// Integration
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
namespace
expert_specialization
{
using
namespace
cute
;
// Different configs for 1SM and 2SM MMA kernel
struct
MMA1SMConfig
{
using
MmaTileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
KernelSchedule
=
cutlass
::
gemm
::
KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
PtrArrayTmaWarpSpecialized1Sm
;
const
static
dim3
preferred_cluster
;
const
static
dim3
fallback_cluster
;
};
const
dim3
MMA1SMConfig
::
preferred_cluster
(
1
,
4
,
1
);
const
dim3
MMA1SMConfig
::
fallback_cluster
(
1
,
2
,
1
);
template
<
typename
_MMAConfig
,
typename
OutputDtype
>
struct
CutlassMxfp8GroupedMmGemmTraits
{
using
MMAConfig
=
_MMAConfig
;
using
ElementInput
=
cutlass
::
float_e4m3_t
;
using
ElementOutput
=
OutputDtype
;
using
ProblemShape
=
cutlass
::
gemm
::
GroupProblemShape
<
Shape
<
int
,
int
,
int
>>
;
// A matrix configuration
using
ElementA
=
cutlass
::
mx_float8_t
<
ElementInput
>
;
using
LayoutA
=
cutlass
::
layout
::
RowMajor
;
constexpr
static
int
AlignmentA
=
32
;
// B matrix configuration
using
ElementB
=
cutlass
::
mx_float8_t
<
ElementInput
>
;
using
LayoutB
=
cutlass
::
layout
::
ColumnMajor
;
constexpr
static
int
AlignmentB
=
32
;
// C/D matrix configuration
using
ElementC
=
void
;
using
ElementD
=
ElementOutput
;
using
LayoutC
=
cutlass
::
layout
::
RowMajor
;
using
LayoutD
=
cutlass
::
layout
::
RowMajor
;
constexpr
static
int
AlignmentC
=
128
/
cutlass
::
sizeof_bits
<
ElementD
>::
value
;
constexpr
static
int
AlignmentD
=
128
/
cutlass
::
sizeof_bits
<
ElementD
>::
value
;
using
ElementAccumulator
=
float
;
static
constexpr
auto
RoundStyle
=
cutlass
::
FloatRoundStyle
::
round_to_nearest
;
using
CustomEVTIdentity
=
// acc
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
epilogue
::
thread
::
Identity
,
ElementD
,
ElementAccumulator
,
RoundStyle
>
,
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
>
;
// Core kernel configurations
using
ArchTag
=
cutlass
::
arch
::
Sm100
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassBlockScaledTensorOp
;
using
StageCountType
=
cutlass
::
gemm
::
collective
::
StageCountAuto
;
// Runtime Cluster Shape
using
ClusterShape
=
Shape
<
int32_t
,
int32_t
,
_1
>
;
// Define Epilogue
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
typename
MMAConfig
::
MmaTileShape
,
ClusterShape
,
Shape
<
_64
,
_64
>
,
ElementAccumulator
,
ElementAccumulator
,
ElementC
,
LayoutC
*
,
AlignmentC
,
ElementD
,
LayoutD
*
,
AlignmentD
,
typename
MMAConfig
::
EpilogueSchedule
,
CustomEVTIdentity
>::
CollectiveOp
;
// Define Mainloop
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
ElementA
,
LayoutA
*
,
AlignmentA
,
ElementB
,
LayoutB
*
,
AlignmentB
,
ElementAccumulator
,
typename
MMAConfig
::
MmaTileShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
typename
MMAConfig
::
KernelSchedule
>::
CollectiveOp
;
// Define GemmKernel
using
GemmKernel
=
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
ProblemShape
,
CollectiveMainloop
,
CollectiveEpilogue
>
;
using
Gemm
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
GemmKernel
>
;
using
ElementSF
=
typename
Gemm
::
GemmKernel
::
ElementSF
;
using
StrideA
=
typename
Gemm
::
GemmKernel
::
InternalStrideA
;
using
StrideB
=
typename
Gemm
::
GemmKernel
::
InternalStrideB
;
using
StrideC
=
typename
Gemm
::
GemmKernel
::
InternalStrideC
;
using
StrideD
=
typename
Gemm
::
GemmKernel
::
InternalStrideD
;
using
LayoutSFA
=
typename
Gemm
::
GemmKernel
::
CollectiveMainloop
::
InternalLayoutSFA
;
using
LayoutSFB
=
typename
Gemm
::
GemmKernel
::
CollectiveMainloop
::
InternalLayoutSFB
;
using
Sm1xxBlkScaledConfig
=
typename
Gemm
::
GemmKernel
::
CollectiveMainloop
::
Sm1xxBlkScaledConfig
;
};
}
// namespace expert_specialization
\ No newline at end of file
csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu
0 → 100644
View file @
3fb4b5fa
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// Adapted from SGLang:
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu
#include <torch/all.h>
#include "mxfp8_experts_quant.cuh"
void
mxfp8_experts_quant
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
problem_sizes
,
const
torch
::
Tensor
&
expert_offsets
,
const
torch
::
Tensor
&
blockscale_offsets
,
torch
::
Tensor
&
quant_output
,
torch
::
Tensor
&
scale_factor
)
{
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
TORCH_CHECK
(
input
.
dim
()
==
2
,
"input must be 2D tensor"
);
TORCH_CHECK
(
input
.
size
(
1
)
%
128
==
0
,
"k must align to 128"
);
TORCH_CHECK
(
input
.
strides
()[
1
]
==
1
,
"input must be row major"
);
TORCH_CHECK
(
problem_sizes
.
dim
()
==
2
,
"problem_sizes must be 2D tensor"
);
TORCH_CHECK
(
problem_sizes
.
dtype
()
==
torch
::
kInt32
,
"problem_sizes must be int32"
);
TORCH_CHECK
(
expert_offsets
.
dtype
()
==
torch
::
kInt32
,
"expert_offsets must be int32"
);
TORCH_CHECK
(
blockscale_offsets
.
dtype
()
==
torch
::
kInt32
,
"blockscale_offsets must be int32"
);
auto
groups
=
problem_sizes
.
size
(
0
);
TORCH_CHECK
(
expert_offsets
.
dim
()
==
1
&&
expert_offsets
.
size
(
0
)
==
groups
,
"expert_offsets must be 1D and have size equal to the number of groups"
);
TORCH_CHECK
(
blockscale_offsets
.
dim
()
==
1
&&
blockscale_offsets
.
size
(
0
)
==
groups
,
"blockscale_offsets must be 1D and have size equal to the number of "
"groups"
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
(
input
.
dtype
()
==
torch
::
kBFloat16
)
{
expert_specialization
::
launch_mxfp8_experts_quant
<
__nv_bfloat16
>
(
input
,
problem_sizes
,
expert_offsets
,
blockscale_offsets
,
quant_output
,
scale_factor
);
}
else
if
(
input
.
dtype
()
==
torch
::
kFloat16
)
{
expert_specialization
::
launch_mxfp8_experts_quant
<
__half
>
(
input
,
problem_sizes
,
expert_offsets
,
blockscale_offsets
,
quant_output
,
scale_factor
);
}
else
{
TORCH_CHECK
(
false
,
"dtype must be kFloat16 or kBFloat16"
);
}
#else
TORCH_CHECK
(
false
,
"No implemented mxfp8_experts_quant for "
"current device"
);
#endif
}
#include "core/registration.h"
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"mxfp8_experts_quant"
,
mxfp8_experts_quant
);
}
\ No newline at end of file
csrc/moe/mxfp8_moe/mxfp8_experts_quant.cuh
0 → 100644
View file @
3fb4b5fa
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// Adapted from SGLang:
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cuh
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/all.h>
#include <cuda/ptx>
#include "cute/tensor.hpp"
namespace
expert_specialization
{
using
namespace
cute
;
constexpr
uint32_t
THREAD_BLOCK_SIZE
=
128
;
constexpr
uint32_t
WARP_SIZE
=
32
;
constexpr
int
BLOCK_M
=
128
;
constexpr
int
BLOCK_K
=
128
;
using
ThrLayout
=
Layout
<
Shape
<
_16
,
_8
>
,
Stride
<
_8
,
_1
>>
;
using
ValLayout
=
Layout
<
Shape
<
_1
,
_16
>>
;
using
SfR2SThrLayout
=
Layout
<
Shape
<
_16
,
_4
>
,
Stride
<
_4
,
_1
>>
;
using
SfR2SValLayout
=
Layout
<
Shape
<
_1
,
_1
>>
;
using
ScaleFactorTileLayout
=
Layout
<
Shape
<
Shape
<
_32
,
_4
>
,
_4
>
,
Stride
<
Stride
<
_16
,
_4
>
,
_1
>>
;
// Fast reciprocal.
inline
__device__
float
reciprocal_approximate_ftz
(
float
a
)
{
float
b
;
asm
volatile
(
"rcp.approx.ftz.f32 %0, %1;
\n
"
:
"=f"
(
b
)
:
"f"
(
a
));
return
b
;
}
// Some code references TRT-LLM:
// https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/quantization.cuh
template
<
typename
FragmentS
,
typename
FragmentD
>
__inline__
__device__
uint8_t
cvt_warp_fp16_to_mxfp8
(
FragmentS
&
fragment_s
,
FragmentD
&
fragment_d
)
{
using
FragmentSLayout
=
typename
FragmentS
::
layout_type
;
using
FragmentDLayout
=
typename
FragmentD
::
layout_type
;
FragmentSLayout
fragment_s_layout
;
FragmentDLayout
fragment_d_layout
;
static_assert
(
is_static
<
FragmentSLayout
>::
value
&&
size
(
fragment_s_layout
)
==
16
);
static_assert
(
is_static
<
FragmentDLayout
>::
value
&&
size
(
fragment_d_layout
)
==
16
);
constexpr
int
eles_per_thr
=
16
;
using
ValType
=
typename
FragmentS
::
element_type
;
using
VecType
=
std
::
conditional_t
<
std
::
is_same_v
<
ValType
,
__nv_bfloat16
>
,
__nv_bfloat162
,
__half2
>
;
VecType
vec
[
8
];
// Assign vals
vec
[
0
].
x
=
fragment_s
(
Int
<
0
>
{});
vec
[
0
].
y
=
fragment_s
(
Int
<
1
>
{});
vec
[
1
].
x
=
fragment_s
(
Int
<
2
>
{});
vec
[
1
].
y
=
fragment_s
(
Int
<
3
>
{});
vec
[
2
].
x
=
fragment_s
(
Int
<
4
>
{});
vec
[
2
].
y
=
fragment_s
(
Int
<
5
>
{});
vec
[
3
].
x
=
fragment_s
(
Int
<
6
>
{});
vec
[
3
].
y
=
fragment_s
(
Int
<
7
>
{});
vec
[
4
].
x
=
fragment_s
(
Int
<
8
>
{});
vec
[
4
].
y
=
fragment_s
(
Int
<
9
>
{});
vec
[
5
].
x
=
fragment_s
(
Int
<
10
>
{});
vec
[
5
].
y
=
fragment_s
(
Int
<
11
>
{});
vec
[
6
].
x
=
fragment_s
(
Int
<
12
>
{});
vec
[
6
].
y
=
fragment_s
(
Int
<
13
>
{});
vec
[
7
].
x
=
fragment_s
(
Int
<
14
>
{});
vec
[
7
].
y
=
fragment_s
(
Int
<
15
>
{});
auto
local_max
=
__habs2
(
vec
[
0
]);
for
(
int
i
=
1
;
i
<
eles_per_thr
/
2
;
i
++
)
{
local_max
=
__hmax2
(
__habs2
(
vec
[
i
]),
local_max
);
}
local_max
=
__hmax2
(
__shfl_xor_sync
(
uint32_t
(
-
1
),
local_max
,
1
),
local_max
);
// Get the final absolute maximum values.
float
block_max
(
0.0
f
);
if
constexpr
(
std
::
is_same_v
<
ValType
,
__nv_bfloat16
>
)
{
block_max
=
__bfloat162float
(
__hmax
(
local_max
.
x
,
local_max
.
y
));
}
else
{
block_max
=
__half2float
(
__hmax
(
local_max
.
x
,
local_max
.
y
));
}
// Get the SF (max value of the vector / max value of mxfp8).
float
sf_val
=
block_max
*
reciprocal_approximate_ftz
(
448.0
f
);
// 8 bits representation of the SF.
uint8_t
fp8_sf_val
;
__nv_fp8_e8m0
tmp_sf_val
;
tmp_sf_val
.
__x
=
__nv_cvt_float_to_e8m0
(
sf_val
,
__NV_SATFINITE
,
cudaRoundPosInf
);
sf_val
=
static_cast
<
float
>
(
tmp_sf_val
);
fp8_sf_val
=
tmp_sf_val
.
__x
;
// Get the output scale (reciprocal of the SFValue).
float
output_scale
=
block_max
!=
0.
f
?
reciprocal_approximate_ftz
(
sf_val
)
:
0.0
f
;
// Convert the input to float.
float2
fp2_vals
[
eles_per_thr
/
2
];
#pragma unroll
for
(
int
i
=
0
;
i
<
eles_per_thr
/
2
;
i
++
)
{
if
constexpr
(
std
::
is_same_v
<
ValType
,
__half
>
)
{
fp2_vals
[
i
]
=
__half22float2
(
vec
[
i
]);
}
else
{
fp2_vals
[
i
]
=
__bfloat1622float2
(
vec
[
i
]);
}
fp2_vals
[
i
].
x
*=
output_scale
;
fp2_vals
[
i
].
y
*=
output_scale
;
}
union
{
uint8_t
bytes
[
16
];
__nv_fp8x2_e4m3
elts
[
8
];
}
u
;
u
.
elts
[
0
]
=
__nv_fp8x2_e4m3
(
fp2_vals
[
0
]);
u
.
elts
[
1
]
=
__nv_fp8x2_e4m3
(
fp2_vals
[
1
]);
u
.
elts
[
2
]
=
__nv_fp8x2_e4m3
(
fp2_vals
[
2
]);
u
.
elts
[
3
]
=
__nv_fp8x2_e4m3
(
fp2_vals
[
3
]);
u
.
elts
[
4
]
=
__nv_fp8x2_e4m3
(
fp2_vals
[
4
]);
u
.
elts
[
5
]
=
__nv_fp8x2_e4m3
(
fp2_vals
[
5
]);
u
.
elts
[
6
]
=
__nv_fp8x2_e4m3
(
fp2_vals
[
6
]);
u
.
elts
[
7
]
=
__nv_fp8x2_e4m3
(
fp2_vals
[
7
]);
fragment_d
(
Int
<
0
>
{})
=
cutlass
::
float_e4m3_t
::
bitcast
(
u
.
bytes
[
0
]);
fragment_d
(
Int
<
1
>
{})
=
cutlass
::
float_e4m3_t
::
bitcast
(
u
.
bytes
[
1
]);
fragment_d
(
Int
<
2
>
{})
=
cutlass
::
float_e4m3_t
::
bitcast
(
u
.
bytes
[
2
]);
fragment_d
(
Int
<
3
>
{})
=
cutlass
::
float_e4m3_t
::
bitcast
(
u
.
bytes
[
3
]);
fragment_d
(
Int
<
4
>
{})
=
cutlass
::
float_e4m3_t
::
bitcast
(
u
.
bytes
[
4
]);
fragment_d
(
Int
<
5
>
{})
=
cutlass
::
float_e4m3_t
::
bitcast
(
u
.
bytes
[
5
]);
fragment_d
(
Int
<
6
>
{})
=
cutlass
::
float_e4m3_t
::
bitcast
(
u
.
bytes
[
6
]);
fragment_d
(
Int
<
7
>
{})
=
cutlass
::
float_e4m3_t
::
bitcast
(
u
.
bytes
[
7
]);
fragment_d
(
Int
<
8
>
{})
=
cutlass
::
float_e4m3_t
::
bitcast
(
u
.
bytes
[
8
]);
fragment_d
(
Int
<
9
>
{})
=
cutlass
::
float_e4m3_t
::
bitcast
(
u
.
bytes
[
9
]);
fragment_d
(
Int
<
10
>
{})
=
cutlass
::
float_e4m3_t
::
bitcast
(
u
.
bytes
[
10
]);
fragment_d
(
Int
<
11
>
{})
=
cutlass
::
float_e4m3_t
::
bitcast
(
u
.
bytes
[
11
]);
fragment_d
(
Int
<
12
>
{})
=
cutlass
::
float_e4m3_t
::
bitcast
(
u
.
bytes
[
12
]);
fragment_d
(
Int
<
13
>
{})
=
cutlass
::
float_e4m3_t
::
bitcast
(
u
.
bytes
[
13
]);
fragment_d
(
Int
<
14
>
{})
=
cutlass
::
float_e4m3_t
::
bitcast
(
u
.
bytes
[
14
]);
fragment_d
(
Int
<
15
>
{})
=
cutlass
::
float_e4m3_t
::
bitcast
(
u
.
bytes
[
15
]);
return
fp8_sf_val
;
}
template
<
typename
TensorS
,
typename
TensorP
,
typename
TensorD
,
typename
TensorSharedSF
,
typename
TensorSF
,
typename
TiledCopyG2R
,
typename
TiledCopyR2G
,
typename
TiledCopyR2S
>
__inline__
__device__
void
mxfp8_experts_quant_tile
(
TensorS
&
tensor_s
,
TensorP
&
tensor_p
,
TensorD
&
tensor_d
,
TensorSharedSF
&
tensor_shared_sf
,
TensorSF
&
tensor_sf
,
int
m
,
TiledCopyG2R
&
tiled_copy_g2r
,
TiledCopyR2G
&
tiled_copy_r2g
,
TiledCopyR2S
&
tiled_copy_r2s
)
{
static_assert
(
size
(
get
<
0
>
(
typename
TensorS
::
layout_type
{}))
==
128
&&
size
(
get
<
1
>
(
typename
TensorS
::
layout_type
{}))
==
128
&&
stride
(
get
<
1
>
(
typename
TensorS
::
layout_type
{}))
==
1
);
static_assert
(
size
(
get
<
0
>
(
typename
TensorD
::
layout_type
{}))
==
128
&&
size
(
get
<
1
>
(
typename
TensorD
::
layout_type
{}))
==
128
&&
stride
(
get
<
1
>
(
typename
TensorD
::
layout_type
{}))
==
1
);
static_assert
(
size
(
get
<
0
>
(
typename
TensorP
::
layout_type
{}))
==
128
&&
size
(
get
<
1
>
(
typename
TensorP
::
layout_type
{}))
==
128
);
static_assert
(
size
(
get
<
0
>
(
typename
TensorSharedSF
::
layout_type
{}))
==
128
&&
size
(
get
<
1
>
(
typename
TensorSharedSF
::
layout_type
{}))
==
4
);
static_assert
(
size
(
get
<
0
>
(
typename
TensorSF
::
layout_type
{}))
==
128
&&
size
(
get
<
1
>
(
typename
TensorSF
::
layout_type
{}))
==
4
);
using
Tiler_MN
=
typename
TiledCopyG2R
::
Tiler_MN
;
auto
tiler_mn
=
Tiler_MN
{};
static_assert
(
size
<
0
>
(
tiler_mn
)
==
16
&&
size
<
1
>
(
tiler_mn
)
==
128
);
auto
tiled_tensor_s
=
tiled_divide
(
tensor_s
,
tiler_mn
);
auto
tiled_tensor_p
=
tiled_divide
(
tensor_p
,
tiler_mn
);
auto
tiled_tensor_d
=
tiled_divide
(
tensor_d
,
tiler_mn
);
static_assert
(
size
<
2
>
(
tiled_tensor_s
)
==
1
);
static_assert
(
size
<
2
>
(
tiled_tensor_p
)
==
1
);
static_assert
(
size
<
2
>
(
tiled_tensor_d
)
==
1
);
auto
squeeze_tiled_tensor_s
=
take
<
0
,
2
>
(
tiled_tensor_s
);
auto
squeeze_tiled_tensor_p
=
take
<
0
,
2
>
(
tiled_tensor_p
);
auto
squeeze_tiled_tensor_d
=
take
<
0
,
2
>
(
tiled_tensor_d
);
using
SF_Tiler_MN
=
typename
TiledCopyR2S
::
Tiler_MN
;
auto
sf_tiler_mn
=
SF_Tiler_MN
{};
static_assert
(
size
<
0
>
(
sf_tiler_mn
)
==
16
&&
size
<
1
>
(
sf_tiler_mn
)
==
4
);
auto
tiled_tensor_sf
=
tiled_divide
(
tensor_sf
,
sf_tiler_mn
);
auto
tiled_tensor_shared_sf
=
tiled_divide
(
tensor_shared_sf
,
sf_tiler_mn
);
auto
squeeze_tiled_tensor_sf
=
take
<
0
,
2
>
(
tiled_tensor_sf
);
auto
squeeze_tiled_tensor_shared_sf
=
take
<
0
,
2
>
(
tiled_tensor_shared_sf
);
constexpr
int
tile_loop_count
=
size
<
1
>
(
tiled_tensor_s
);
constexpr
int
rows_in_tile
=
16
;
// We don't need to clear shared memory
// clear(squeeze_tiled_tensor_shared_sf);
#pragma unroll 4
for
(
int
t
=
0
;
t
<
tile_loop_count
;
t
++
)
{
if
(
t
*
rows_in_tile
>=
m
)
{
break
;
}
auto
current_copy_tile_s
=
tensor
<
0
>
(
squeeze_tiled_tensor_s
(
_
,
t
));
auto
current_copy_tile_p
=
tensor
<
0
>
(
squeeze_tiled_tensor_p
(
_
,
t
));
auto
current_copy_tile_d
=
tensor
<
0
>
(
squeeze_tiled_tensor_d
(
_
,
t
));
auto
current_copy_tile_sf
=
tensor
<
0
>
(
squeeze_tiled_tensor_sf
(
_
,
t
));
auto
current_copy_tile_shared_sf
=
tensor
<
0
>
(
squeeze_tiled_tensor_shared_sf
(
_
,
t
));
// Global to Register copy
auto
thr_copy_g2r
=
tiled_copy_g2r
.
get_thread_slice
(
threadIdx
.
x
);
auto
thr_tile_g2r_s
=
thr_copy_g2r
.
partition_S
(
current_copy_tile_s
);
auto
thr_tile_g2r_p
=
thr_copy_g2r
.
partition_S
(
current_copy_tile_p
);
auto
input_fragment
=
make_fragment_like
(
thr_tile_g2r_s
);
// Register to Global copy
auto
thr_copy_r2g
=
tiled_copy_r2g
.
get_thread_slice
(
threadIdx
.
x
);
auto
thr_tile_r2g_d
=
thr_copy_r2g
.
partition_D
(
current_copy_tile_d
);
auto
thr_tile_r2g_p
=
thr_copy_r2g
.
partition_D
(
current_copy_tile_p
);
auto
output_fragment
=
make_fragment_like
(
thr_tile_r2g_d
);
// Register to Shared copy
auto
thr_copy_r2s
=
tiled_copy_r2s
.
get_thread_slice
(
threadIdx
.
x
/
2
);
auto
thr_tile_r2s_shared_sf
=
thr_copy_r2s
.
partition_D
(
current_copy_tile_shared_sf
);
auto
shared_sf_fragment
=
make_fragment_like
(
thr_tile_r2s_shared_sf
);
// CopyG2R & convert & CopyR2G
copy_if
(
tiled_copy_g2r
,
thr_tile_g2r_p
,
thr_tile_g2r_s
,
input_fragment
);
uint8_t
fp8_sf_val
=
cvt_warp_fp16_to_mxfp8
(
input_fragment
,
output_fragment
);
copy_if
(
tiled_copy_r2g
,
thr_tile_r2g_p
,
output_fragment
,
thr_tile_r2g_d
);
shared_sf_fragment
[
0
]
=
fp8_sf_val
;
// Before first copy r2s, clear shared memory and wait previous group
if
(
t
==
0
&&
threadIdx
.
x
==
0
)
{
// Wait for the group to have completed reading from shared memory.
cuda
::
ptx
::
cp_async_bulk_wait_group_read
(
cuda
::
ptx
::
n32_t
<
0
>
());
}
__syncthreads
();
if
(
threadIdx
.
x
%
2
==
0
)
{
copy
(
tiled_copy_r2s
,
shared_sf_fragment
,
thr_tile_r2s_shared_sf
);
}
__syncthreads
();
}
// Wait for shared memory writes to be visible to TMA engine.
cuda
::
ptx
::
fence_proxy_async
(
cuda
::
ptx
::
space_shared
);
// b)
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
cuda
::
ptx
::
cp_async_bulk
(
cuda
::
ptx
::
space_global
,
cuda
::
ptx
::
space_shared
,
squeeze_tiled_tensor_sf
.
data
().
get
(),
squeeze_tiled_tensor_shared_sf
.
data
().
get
(),
512
);
// Wait for TMA transfer to have finished reading shared memory.
// Create a "bulk async-group" out of the previous bulk copy operation.
cuda
::
ptx
::
cp_async_bulk_commit_group
();
}
__syncthreads
();
}
template
<
typename
T_IN
,
typename
TiledCopyG2R
,
typename
TiledCopyR2G
,
typename
TiledCopyR2S
>
__global__
void
mxfp8_experts_quant_kernel
(
const
T_IN
*
input
,
const
int
*
problem_sizes
,
const
int
*
expert_offsets
,
const
int
*
blockscale_offsets
,
cutlass
::
float_e4m3_t
*
quant_output
,
uint8_t
*
scale_factor
,
int
groups
,
TiledCopyG2R
tiled_copy_g2r
,
TiledCopyR2G
tiled_copy_r2g
,
TiledCopyR2S
tiled_copy_r2s
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000
__shared__
__align__
(
512
)
uint8_t
shared_memory
[
512
];
ScaleFactorTileLayout
scale_factor_tile_layout
{};
auto
scale_factor_shared
=
make_tensor
(
make_smem_ptr
(
shared_memory
),
scale_factor_tile_layout
);
// ((_32,_4), _4):((_16,_4), _1)
// TODO: Transform Groupwise Schedule into a more efficient Schedule
for
(
int
g
=
0
;
g
<
groups
;
g
++
)
{
int
m
=
problem_sizes
[
g
*
3
+
0
];
int
k
=
problem_sizes
[
g
*
3
+
2
];
int64_t
expert_offset
=
static_cast
<
int64_t
>
(
expert_offsets
[
g
]);
int64_t
blockscale_offset
=
static_cast
<
int64_t
>
(
blockscale_offsets
[
g
]);
auto
input_tensor
=
make_tensor
(
make_gmem_ptr
(
input
+
expert_offset
*
k
),
make_layout
(
make_shape
(
m
,
k
),
LayoutRight
{}));
// (M, K):(K, 1) half_t/bfloat16_t
auto
quant_output_tensor
=
make_tensor
(
make_gmem_ptr
(
quant_output
+
expert_offset
*
k
),
make_layout
(
make_shape
(
m
,
k
),
LayoutRight
{}));
// (M, K):(K, 1) cutlass::float_e4m3_t
auto
scale_factor_shape
=
make_shape
(
ceil_div
(
m
,
128
)
*
128
,
k
/
32
);
auto
scale_factor_layout
=
tile_to_shape
(
scale_factor_tile_layout
,
scale_factor_shape
,
LayoutRight
{});
// layout<0>(layout<0>(scale_factor_layout)) (_32,_4):(_16,_4) -- static
// layout<1>(layout<0>(scale_factor_layout)) M_align_128 / 128 -- dynamic
// shape dynamic stride layout<0>(layout<1>(scale_factor_layout)) _4:_1 --
// static layout<1>(layout<1>(scale_factor_layout)) (K / 32) / 4 : _512 --
// dynamic shape static stride
// Reshape to zipped layout for 1D indexing
auto
zipped_scale_factor_layout
=
make_layout
(
make_layout
(
layout
<
0
>
(
layout
<
0
>
(
scale_factor_layout
)),
layout
<
0
>
(
layout
<
1
>
(
scale_factor_layout
))),
make_layout
(
layout
<
1
>
(
layout
<
0
>
(
scale_factor_layout
)),
layout
<
1
>
(
layout
<
1
>
(
scale_factor_layout
))));
// (((_32,_4),_4),(M_align_128 /
// 128,(K / 32) /
// 4)):(((_16,_4),_1),(?,_512))
auto
scale_factor_tensor
=
make_tensor
(
make_gmem_ptr
(
scale_factor
+
blockscale_offset
*
(
k
/
32
)),
zipped_scale_factor_layout
);
// Used for cases where M is not divisible by 128 (most scenarios).
auto
input_shape
=
shape
(
input_tensor
);
// (M, K):(K, 1)
auto
identity_tensor
=
make_identity_tensor
(
input_shape
);
auto
predict_tensor
=
cute
::
lazy
::
transform
(
identity_tensor
,
[
&
](
auto
c
)
{
return
elem_less
(
c
,
input_shape
);
});
// (_128, _128)
auto
tiler
=
make_shape
(
Int
<
BLOCK_M
>
{},
Int
<
BLOCK_K
>
{});
auto
tiled_input_tensor
=
zipped_divide
(
input_tensor
,
tiler
);
// ((128, 128), (cdiv(M, 128), cdiv(K, 128)))
auto
tiled_quant_output_tensor
=
zipped_divide
(
quant_output_tensor
,
tiler
);
// ((128, 128), (cdiv(M, 128), cdiv(K, 128)))
auto
tiled_predict_tensor
=
zipped_divide
(
predict_tensor
,
tiler
);
// ((128, 128), (cdiv(M, 128), cdiv(K, 128)))
auto
total_tiles
=
size
<
1
>
(
tiled_input_tensor
);
// cdiv(M, 128) * cdiv(K, 128)
decltype
(
total_tiles
)
blk_offset
=
blockIdx
.
x
;
while
(
blk_offset
<
total_tiles
)
{
auto
current_input_tile
=
tensor
<
0
>
(
tiled_input_tensor
(
_
,
blk_offset
));
auto
current_quant_output_tile
=
tensor
<
0
>
(
tiled_quant_output_tensor
(
_
,
blk_offset
));
auto
current_predict_tile
=
tensor
<
0
>
(
tiled_predict_tensor
(
_
,
blk_offset
));
auto
current_scale_factor_tile
=
tensor
<
0
>
(
scale_factor_tensor
(
_
,
blk_offset
));
mxfp8_experts_quant_tile
<
decltype
(
current_input_tile
),
decltype
(
current_predict_tile
),
decltype
(
current_quant_output_tile
),
decltype
(
scale_factor_shared
),
decltype
(
current_scale_factor_tile
),
TiledCopyG2R
,
TiledCopyR2G
,
TiledCopyR2S
>
(
current_input_tile
,
current_predict_tile
,
current_quant_output_tile
,
scale_factor_shared
,
current_scale_factor_tile
,
m
,
tiled_copy_g2r
,
tiled_copy_r2g
,
tiled_copy_r2s
);
blk_offset
+=
gridDim
.
x
;
}
}
#endif
}
template
<
typename
T_IN
>
void
launch_mxfp8_experts_quant
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
problem_sizes
,
const
torch
::
Tensor
&
expert_offsets
,
const
torch
::
Tensor
&
blockscale_offsets
,
torch
::
Tensor
&
quant_output
,
torch
::
Tensor
&
scale_factor
)
{
ThrLayout
thr_layout
{};
ValLayout
val_layout
{};
SfR2SThrLayout
r2s_thr_layout
{};
SfR2SValLayout
r2s_val_layout
{};
using
CopyOpG2R
=
UniversalCopy
<
cutlass
::
AlignedArray
<
T_IN
,
size
(
val_layout
)
>>
;
using
CopyAtomG2R
=
cute
::
Copy_Atom
<
CopyOpG2R
,
T_IN
>
;
auto
tiled_copy_g2r
=
cute
::
make_tiled_copy
(
CopyAtomG2R
{},
thr_layout
,
val_layout
);
// Tiler_MN: (16, 128)
using
CopyOpR2G
=
UniversalCopy
<
cutlass
::
AlignedArray
<
cutlass
::
float_e4m3_t
,
size
(
val_layout
)
>>
;
using
CopyAtomR2G
=
cute
::
Copy_Atom
<
CopyOpR2G
,
cutlass
::
float_e4m3_t
>
;
auto
tiled_copy_r2g
=
cute
::
make_tiled_copy
(
CopyAtomR2G
{},
thr_layout
,
val_layout
);
// Tiler_MN: (16, 128)
using
CopyOpR2S
=
UniversalCopy
<
cutlass
::
AlignedArray
<
uint8_t
,
size
(
r2s_val_layout
)
>>
;
using
CopyAtomR2S
=
cute
::
Copy_Atom
<
CopyOpR2S
,
uint8_t
>
;
auto
tiled_copy_r2s
=
cute
::
make_tiled_copy
(
CopyAtomR2S
{},
r2s_thr_layout
,
r2s_val_layout
);
// Tiler_MN: (16, 4)
int
max_active_blocks_per_sm
=
-
1
;
AT_CUDA_CHECK
(
cudaOccupancyMaxActiveBlocksPerMultiprocessor
(
&
max_active_blocks_per_sm
,
mxfp8_experts_quant_kernel
<
T_IN
,
decltype
(
tiled_copy_g2r
),
decltype
(
tiled_copy_r2g
),
decltype
(
tiled_copy_r2s
)
>
,
THREAD_BLOCK_SIZE
,
0
));
dim3
grid
(
at
::
cuda
::
getCurrentDeviceProperties
()
->
multiProcessorCount
*
max_active_blocks_per_sm
,
1
,
1
);
dim3
block
(
THREAD_BLOCK_SIZE
,
1
,
1
);
int
num_experts
=
(
int
)
problem_sizes
.
size
(
0
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
mxfp8_experts_quant_kernel
<
T_IN
,
decltype
(
tiled_copy_g2r
),
decltype
(
tiled_copy_r2g
),
decltype
(
tiled_copy_r2s
)
>
<<<
grid
,
block
,
0
,
stream
>>>
(
reinterpret_cast
<
const
T_IN
*>
(
input
.
data_ptr
()),
reinterpret_cast
<
const
int
*>
(
problem_sizes
.
data_ptr
()),
reinterpret_cast
<
const
int
*>
(
expert_offsets
.
data_ptr
()),
reinterpret_cast
<
const
int
*>
(
blockscale_offsets
.
data_ptr
()),
reinterpret_cast
<
cutlass
::
float_e4m3_t
*>
(
quant_output
.
data_ptr
()),
reinterpret_cast
<
uint8_t
*>
(
scale_factor
.
data_ptr
()),
num_experts
,
tiled_copy_g2r
,
tiled_copy_r2g
,
tiled_copy_r2s
);
}
}
// namespace expert_specialization
\ No newline at end of file
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h
View file @
3fb4b5fa
...
...
@@ -57,7 +57,7 @@ void sortAndScanExpert(const int* expert_for_source_row, const int* source_rows,
template
<
typename
T
>
void
expandInputRowsKernelLauncher
(
T
const
*
unpermuted_input
,
T
*
permuted_output
,
int
*
sorted_experts
,
T
const
*
unpermuted_input
,
T
*
permuted_output
,
int
const
*
expanded_dest_row_to_expanded_source_row
,
int
*
expanded_source_row_to_expanded_dest_row
,
int
*
permuted_idx
,
int64_t
const
*
expert_first_token_offset
,
int64_t
const
num_rows
,
...
...
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl
View file @
3fb4b5fa
...
...
@@ -2,7 +2,7 @@
template <typename T, bool CHECK_SKIPPED>
__global__ void expandInputRowsKernel(
T const* unpermuted_input, T* permuted_output,
int* sorted_experts,
T const* unpermuted_input, T* permuted_output,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t const* expert_first_token_offset, int64_t const num_rows,
...
...
@@ -16,7 +16,6 @@ __global__ void expandInputRowsKernel(
int64_t expanded_dest_row = blockIdx.x;
int64_t const expanded_source_row =
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
int expert_id = sorted_experts[expanded_dest_row];
if (threadIdx.x == 0) {
assert(expanded_dest_row <= INT32_MAX);
...
...
@@ -54,7 +53,7 @@ __global__ void expandInputRowsKernel(
template <typename T>
void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output,
int* sorted_experts,
T const* unpermuted_input, T* permuted_output,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t const* expert_first_token_offset, int64_t const num_rows,
...
...
@@ -70,12 +69,12 @@ void expandInputRowsKernelLauncher(
bool is_check_skip = num_valid_tokens_ptr != nullptr;
auto func = func_map[is_check_skip];
func<<<blocks, threads, 0, stream>>>(
unpermuted_input, permut
ed_ou
tput, sorted_experts
,
expanded_
dest
_row_to_expanded_
source
_row,
expanded_source_row_to_expanded_dest_row, permuted_idx
,
expert_first_token_offset,
num_rows, num_valid_tokens_ptr, cols, k,
num_local_experts);
func<<<blocks, threads, 0, stream>>>(
unpermuted_input, permuted_output,
expanded_dest_row_to_expand
ed_
s
ou
rce_row
,
expanded_
source
_row_to_expanded_
dest
_row,
permuted_idx, expert_first_token_offset
,
num_rows, num_valid_tokens_ptr, cols, k,
num_local_experts);
}
template <class T, class U>
...
...
csrc/moe/router_gemm.cu
0 → 100644
View file @
3fb4b5fa
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// bf16 x bf16 -> fp32 router GEMM via cuBLAS.
// Uses CUBLAS_COMPUTE_32F so bf16 operands accumulate into fp32,
// matching TRT-LLM's cuBLAS fallback behaviour in dsv3RouterGemmOp.
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <cublas_v2.h>
// cuBLAS column-major math for row-major PyTorch tensors:
// weight[N,K]_row lda=K -> cuBLAS sees (K,N) col-major; CUBLAS_OP_T ->
// (N,K) input[M,K]_row ldb=K -> cuBLAS sees (K,M) col-major; CUBLAS_OP_N
// -> (K,M) out[M,N]_row ldc=N -> cuBLAS sees (N,M) col-major (written as
// output^T)
// cuBLAS: C(N,M) = weight(N,K) @ input(K,M) => C^T = output[M,N]
// params: m=N, n=M, k=K, lda=K (weight), ldb=K (input), ldc=N (output)
torch
::
Tensor
router_gemm_bf16_fp32
(
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
weight
)
{
TORCH_CHECK
(
input
.
dtype
()
==
torch
::
kBFloat16
,
"router_gemm_bf16_fp32: input must be bfloat16"
);
TORCH_CHECK
(
weight
.
dtype
()
==
torch
::
kBFloat16
,
"router_gemm_bf16_fp32: weight must be bfloat16"
);
TORCH_CHECK
(
input
.
dim
()
==
2
&&
weight
.
dim
()
==
2
,
"router_gemm_bf16_fp32: input and weight must be 2-D"
);
TORCH_CHECK
(
input
.
size
(
1
)
==
weight
.
size
(
1
),
"router_gemm_bf16_fp32: inner dimensions must match"
);
int64_t
const
M
=
input
.
size
(
0
);
int64_t
const
N
=
weight
.
size
(
0
);
int64_t
const
K
=
input
.
size
(
1
);
auto
out
=
torch
::
empty
({
M
,
N
},
input
.
options
().
dtype
(
torch
::
kFloat32
));
cublasHandle_t
handle
=
at
::
cuda
::
getCurrentCUDABlasHandle
();
TORCH_CUDABLAS_CHECK
(
cublasSetStream
(
handle
,
at
::
cuda
::
getCurrentCUDAStream
()));
float
const
alpha
=
1.0
f
;
float
const
beta
=
0.0
f
;
TORCH_CUDABLAS_CHECK
(
cublasGemmEx
(
handle
,
CUBLAS_OP_T
,
CUBLAS_OP_N
,
static_cast
<
int
>
(
N
),
static_cast
<
int
>
(
M
),
static_cast
<
int
>
(
K
),
&
alpha
,
weight
.
data_ptr
(),
CUDA_R_16BF
,
static_cast
<
int
>
(
K
),
input
.
data_ptr
(),
CUDA_R_16BF
,
static_cast
<
int
>
(
K
),
&
beta
,
out
.
data_ptr
(),
CUDA_R_32F
,
static_cast
<
int
>
(
N
),
CUBLAS_COMPUTE_32F
,
CUBLAS_GEMM_DEFAULT
));
return
out
;
}
csrc/moe/torch_bindings.cpp
View file @
3fb4b5fa
...
...
@@ -124,6 +124,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"routed_scaling_factor, Tensor bias, int scoring_func) -> (Tensor, "
"Tensor)"
);
m
.
impl
(
"grouped_topk"
,
torch
::
kCUDA
,
&
grouped_topk
);
// cuBLAS bf16 x bf16 -> fp32 router GEMM (fallback for non-SM90 / batch > 16)
m
.
def
(
"router_gemm_bf16_fp32(Tensor input, Tensor weight) -> Tensor"
);
m
.
impl
(
"router_gemm_bf16_fp32"
,
torch
::
kCUDA
,
&
router_gemm_bf16_fp32
);
// DeepSeek V3 optimized router GEMM for SM90+
m
.
def
(
"dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"
);
// conditionally compiled so impl registration is in source file
#endif
}
...
...
csrc/ops.h
View file @
3fb4b5fa
...
...
@@ -114,6 +114,10 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
int64_t
numRows
,
int64_t
stride0
,
int64_t
stride1
,
int64_t
topK
);
void
large_context_topk
(
const
torch
::
Tensor
&
score
,
torch
::
Tensor
&
indices
,
const
torch
::
Tensor
&
lengths
,
std
::
optional
<
torch
::
Tensor
>
row_starts_opt
);
// void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& weight, torch::Tensor& scale,
// double epsilon);
...
...
@@ -265,13 +269,13 @@ void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
const
int64_t
n
,
const
int64_t
k
,
const
bool
swap_ab
);
void
get_cutlass_
pplx
_moe_mm_data
(
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
const
torch
::
Tensor
&
expert_num_tokens
,
const
int64_t
num_local_experts
,
const
int64_t
padded_m
,
const
int64_t
n
,
const
int64_t
k
);
void
get_cutlass_
batched
_moe_mm_data
(
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
const
torch
::
Tensor
&
expert_num_tokens
,
const
int64_t
num_local_experts
,
const
int64_t
padded_m
,
const
int64_t
n
,
const
int64_t
k
);
void
cutlass_scaled_mm_azp
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
...
...
@@ -291,10 +295,14 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
std
::
vector
<
torch
::
Tensor
>
cutlass_sparse_compress
(
torch
::
Tensor
const
&
a
);
void
scaled_fp4_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
const
&
input_scale
,
bool
is_sf_swizzled_layout
);
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
scaled_fp4_quant_func
(
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input_scale
,
bool
is_sf_swizzled_layout
);
void
scaled_fp4_quant_out
(
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input_scale
,
bool
is_sf_swizzled_layout
,
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output_scale
);
void
scaled_fp4_experts_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output_scale
,
...
...
@@ -311,7 +319,9 @@ void silu_and_mul_scaled_fp4_experts_quant(
void
per_token_group_quant_fp8
(
const
torch
::
Tensor
&
input
,
torch
::
Tensor
&
output_q
,
torch
::
Tensor
&
output_s
,
int64_t
group_size
,
double
eps
,
double
fp8_min
,
double
fp8_max
,
bool
scale_ue8m0
);
double
fp8_max
,
bool
scale_ue8m0
,
bool
dummy_is_scale_transposed
,
bool
dummy_is_tma_aligned
);
void
per_token_group_quant_int8
(
const
torch
::
Tensor
&
input
,
torch
::
Tensor
&
output_q
,
...
...
@@ -365,7 +375,9 @@ void selective_scan_fwd(
const
torch
::
Tensor
&
ssm_states
,
int64_t
pad_slot_id
,
int64_t
block_size
,
const
std
::
optional
<
torch
::
Tensor
>&
block_idx_first_scheduled_token
,
const
std
::
optional
<
torch
::
Tensor
>&
block_idx_last_scheduled_token
,
const
std
::
optional
<
torch
::
Tensor
>&
initial_state_idx
);
const
std
::
optional
<
torch
::
Tensor
>&
initial_state_idx
,
const
std
::
optional
<
torch
::
Tensor
>&
cu_chunk_seqlen
,
const
std
::
optional
<
torch
::
Tensor
>&
last_chunk_indices
);
torch
::
Tensor
dynamic_4bit_int_moe_cpu
(
torch
::
Tensor
x
,
torch
::
Tensor
topk_ids
,
torch
::
Tensor
topk_weights
,
...
...
@@ -404,3 +416,8 @@ void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
int64_t
quant_level
,
bool
cast_bf2half
=
false
);
int64_t
qr_max_size
();
#endif
#ifndef USE_ROCM
void
dsv3_fused_a_gemm
(
torch
::
Tensor
&
output
,
torch
::
Tensor
const
&
mat_a
,
torch
::
Tensor
const
&
mat_b
);
#endif
\ No newline at end of file
csrc/quantization/activation_kernels.cu
View file @
3fb4b5fa
...
...
@@ -542,7 +542,7 @@ __global__ void silu_mul_fp8_quant_deep_gemm_kernel(
if
(
!
lane_id
)
{
// Store scales.
if
constexpr
(
std
::
is_same
<
scale_t
,
uint8_t
>::
value
)
{
// Packed UE8M
O
format. Remove Mantissa.
// Packed UE8M
0
format. Remove Mantissa.
*
y_s_ptr
=
reinterpret_cast
<
int16_t
&>
(
y_s
)
>>
7
;
bool
const
jump_pack
=
(
current_group_id
+
1
)
%
4
==
0
;
...
...
csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu
View file @
3fb4b5fa
...
...
@@ -39,12 +39,12 @@ namespace vllm {
template
<
class
Type
,
bool
UE8M0_SF
=
false
>
__global__
void
__launch_bounds__
(
512
,
VLLM_BLOCKS_PER_SM
(
512
))
silu_mul_cvt_fp16_to_fp4
(
int32_t
numRows
,
int32_t
numCols
,
int32_t
num_pa
dd
ed_cols
,
int32_t
num_pa
ck
ed_cols
,
Type
const
*
__restrict__
in
,
float
const
*
__restrict__
SFScale
,
uint32_t
*
__restrict__
out
,
uint32_t
*
__restrict__
SFout
)
{
using
PackedVec
=
vllm
::
PackedVec
<
Type
>
;
using
PackedVec
=
vllm
::
PackedVec
<
Type
,
CVT_FP4_PACK16
>
;
static
constexpr
int
CVT_FP4_NUM_THREADS_PER_SF
=
(
CVT_FP4_SF_VEC_SIZE
/
CVT_FP4_ELTS_PER_THREAD
);
static_assert
(
sizeof
(
PackedVec
)
==
sizeof
(
Type
)
*
CVT_FP4_ELTS_PER_THREAD
,
...
...
@@ -63,7 +63,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
// Input tensor row/col loops.
for
(
int
rowIdx
=
blockIdx
.
x
;
rowIdx
<
numRows
;
rowIdx
+=
gridDim
.
x
)
{
if
(
colIdx
<
num_pa
dd
ed_cols
)
{
if
(
colIdx
<
num_pa
ck
ed_cols
)
{
PackedVec
in_vec
;
PackedVec
in_vec2
;
int64_t
inOffset
=
...
...
@@ -73,19 +73,19 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
bool
valid
=
(
rowIdx
<
numRows
)
&&
(
elem_idx
<
numCols
);
if
constexpr
(
CVT_FP4_PACK16
)
{
ld256_or_zero
_cg_u32
<
Type
>
(
in_vec
,
&
reinterpret_cast
<
const
uint32_t
*>
(
in
)[
inOffset
*
8
],
valid
);
ld256_or_zero
_cg_u32
<
Type
>
(
in_vec2
,
&
reinterpret_cast
<
const
uint32_t
*>
(
in
)[
inOffset2
*
8
],
valid
);
ld256_
cg_
or_zero
(
reinterpret_cast
<
u32x8_t
&>
(
in_vec
),
&
reinterpret_cast
<
const
uint32_t
*>
(
in
)[
inOffset
*
8
],
valid
);
ld256_
cg_
or_zero
(
reinterpret_cast
<
u32x8_t
&>
(
in_vec2
),
&
reinterpret_cast
<
const
uint32_t
*>
(
in
)[
inOffset2
*
8
],
valid
);
}
else
{
ld128_or_zero
_cg_u32
<
Type
>
(
in_vec
,
&
reinterpret_cast
<
const
uint32_t
*>
(
in
)[
inOffset
*
4
],
valid
);
ld128_or_zero
_cg_u32
<
Type
>
(
in_vec2
,
&
reinterpret_cast
<
const
uint32_t
*>
(
in
)[
inOffset2
*
4
],
valid
);
ld128_
cg_
or_zero
(
reinterpret_cast
<
uint4
&>
(
in_vec
),
&
reinterpret_cast
<
const
uint32_t
*>
(
in
)[
inOffset
*
4
],
valid
);
ld128_
cg_
or_zero
(
reinterpret_cast
<
uint4
&>
(
in_vec2
),
&
reinterpret_cast
<
const
uint32_t
*>
(
in
)[
inOffset2
*
4
],
valid
);
}
// Compute silu and mul
...
...
@@ -107,7 +107,9 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
(
uint64_t
(
out_val
.
hi
)
<<
32
)
|
uint64_t
(
out_val
.
lo
);
reinterpret_cast
<
uint64_t
*>
(
out
)[
outOffset
>>
1
]
=
packed64
;
}
else
{
out
[
inOffset
]
=
out_val
;
int64_t
outOffset
=
rowIdx
*
(
numCols
/
CVT_FP4_ELTS_PER_THREAD
)
+
colIdx
;
out
[
outOffset
]
=
out_val
;
}
}
}
...
...
@@ -140,9 +142,9 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
int
const
numBlocksPerSM
=
vllm_runtime_blocks_per_sm
(
static_cast
<
int
>
(
block
.
x
));
int
sf_n_unpadded
=
int
(
n
/
CVT_FP4_
SF_VEC_SIZE
);
int
num_packed_cols
=
int
(
n
/
CVT_FP4_
ELTS_PER_THREAD
);
int
grid_y
=
vllm
::
div_round_up
(
sf_n_unpadded
,
static_cast
<
int
>
(
block
.
x
));
int
grid_y
=
vllm
::
div_round_up
(
num_packed_cols
,
static_cast
<
int
>
(
block
.
x
));
int
grid_x
=
std
::
min
(
int
(
m
),
std
::
max
(
1
,
(
multiProcessorCount
*
numBlocksPerSM
)
/
grid_y
));
dim3
grid
(
grid_x
,
grid_y
);
...
...
@@ -152,7 +154,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
using
cuda_type
=
vllm
::
CUDATypeConverter
<
scalar_t
>::
Type
;
auto
input_ptr
=
static_cast
<
cuda_type
const
*>
(
input
.
data_ptr
());
vllm
::
silu_mul_cvt_fp16_to_fp4
<
cuda_type
><<<
grid
,
block
,
0
,
stream
>>>
(
m
,
n
,
sf_n_unpadded
,
input_ptr
,
input_sf_ptr
,
m
,
n
,
num_packed_cols
,
input_ptr
,
input_sf_ptr
,
reinterpret_cast
<
uint32_t
*>
(
output_ptr
),
reinterpret_cast
<
uint32_t
*>
(
sf_out
));
});
...
...
csrc/quantization/fp4/nvfp4_experts_quant.cu
View file @
3fb4b5fa
...
...
@@ -43,7 +43,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
uint32_t
*
input_offset_by_experts
,
uint32_t
*
output_scale_offset_by_experts
,
int
n_experts
,
bool
low_latency
)
{
using
PackedVec
=
PackedVec
<
Type
>
;
using
PackedVec
=
PackedVec
<
Type
,
CVT_FP4_PACK16
>
;
static
constexpr
int
CVT_FP4_NUM_THREADS_PER_SF
=
(
CVT_FP4_SF_VEC_SIZE
/
CVT_FP4_ELTS_PER_THREAD
);
static_assert
(
sizeof
(
PackedVec
)
==
sizeof
(
Type
)
*
CVT_FP4_ELTS_PER_THREAD
,
...
...
@@ -155,7 +155,7 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
float
const
*
SFScale
,
uint32_t
*
out
,
uint32_t
*
SFout
,
uint32_t
*
input_offset_by_experts
,
uint32_t
*
output_scale_offset_by_experts
,
int
n_experts
)
{
using
PackedVec
=
PackedVec
<
Type
>
;
using
PackedVec
=
PackedVec
<
Type
,
CVT_FP4_PACK16
>
;
static
constexpr
int
CVT_FP4_NUM_THREADS_PER_SF
=
(
CVT_FP4_SF_VEC_SIZE
/
CVT_FP4_ELTS_PER_THREAD
);
static_assert
(
sizeof
(
PackedVec
)
==
sizeof
(
Type
)
*
CVT_FP4_ELTS_PER_THREAD
,
...
...
csrc/quantization/fp4/nvfp4_quant_entry.cu
View file @
3fb4b5fa
...
...
@@ -16,6 +16,8 @@
#include <torch/all.h>
#include "nvfp4_utils.cuh"
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void
scaled_fp4_quant_sm1xxa
(
torch
::
Tensor
const
&
output
,
...
...
@@ -51,9 +53,10 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
torch
::
Tensor
const
&
output_scale_offset_by_experts
);
#endif
void
scaled_fp4_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
output_sf
,
torch
::
Tensor
const
&
input_sf
,
bool
is_sf_swizzled_layout
)
{
void
scaled_fp4_quant_out
(
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input_sf
,
bool
is_sf_swizzled_layout
,
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output_sf
)
{
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
return
scaled_fp4_quant_sm1xxa
(
output
,
input
,
output_sf
,
input_sf
,
...
...
@@ -62,6 +65,34 @@ void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled nvfp4 quantization kernel"
);
}
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
scaled_fp4_quant_func
(
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input_sf
,
bool
is_sf_swizzled_layout
)
{
int64_t
n
=
input
.
size
(
-
1
);
int64_t
m
=
input
.
numel
()
/
n
;
auto
device
=
input
.
device
();
// Two fp4 values packed into a uint8
auto
output
=
torch
::
empty
(
{
m
,
n
/
2
},
torch
::
TensorOptions
().
device
(
device
).
dtype
(
torch
::
kUInt8
));
torch
::
Tensor
output_sf
;
if
(
is_sf_swizzled_layout
)
{
auto
[
sf_m
,
sf_n
]
=
vllm
::
computeSwizzledSFShape
(
m
,
n
);
output_sf
=
torch
::
empty
(
{
sf_m
,
sf_n
},
torch
::
TensorOptions
().
device
(
device
).
dtype
(
torch
::
kInt32
));
}
else
{
output_sf
=
torch
::
empty
(
{
m
,
n
/
CVT_FP4_SF_VEC_SIZE
},
torch
::
TensorOptions
().
device
(
device
).
dtype
(
torch
::
kUInt8
));
}
scaled_fp4_quant_out
(
input
,
input_sf
,
is_sf_swizzled_layout
,
output
,
output_sf
);
return
{
output
,
output_sf
};
}
void
scaled_fp4_experts_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input_global_scale
,
...
...
csrc/quantization/fp4/nvfp4_quant_kernels.cu
View file @
3fb4b5fa
...
...
@@ -42,7 +42,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
Type
const
*
__restrict__
in
,
float
const
*
__restrict__
SFScale
,
uint32_t
*
__restrict__
out
,
uint32_t
*
__restrict__
SFout
)
{
using
PackedVec
=
vllm
::
PackedVec
<
Type
>
;
using
PackedVec
=
vllm
::
PackedVec
<
Type
,
CVT_FP4_PACK16
>
;
static
constexpr
int
CVT_FP4_NUM_THREADS_PER_SF
=
(
CVT_FP4_SF_VEC_SIZE
/
CVT_FP4_ELTS_PER_THREAD
);
...
...
@@ -71,13 +71,13 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
// If we are outside valid rows OR outside valid columns -> Use Zeros
bool
valid
=
(
rowIdx
<
numRows
)
&&
(
elem_idx
<
numCols
);
if
constexpr
(
CVT_FP4_PACK16
)
{
ld256_or_zero
_cg_u32
<
Type
>
(
in_vec
,
&
reinterpret_cast
<
const
uint32_t
*>
(
in
)[
inOffset
*
8
],
valid
);
ld256_
cg_
or_zero
(
reinterpret_cast
<
u32x8_t
&>
(
in_vec
),
&
reinterpret_cast
<
const
uint32_t
*>
(
in
)[
inOffset
*
8
],
valid
);
}
else
{
ld128_or_zero
_cg_u32
<
Type
>
(
in_vec
,
&
reinterpret_cast
<
const
uint32_t
*>
(
in
)[
inOffset
*
4
],
valid
);
ld128_
cg_
or_zero
(
reinterpret_cast
<
uint4
&>
(
in_vec
),
&
reinterpret_cast
<
const
uint32_t
*>
(
in
)[
inOffset
*
4
],
valid
);
}
auto
sf_out
=
...
...
@@ -109,11 +109,12 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
template
<
class
Type
,
bool
UE8M0_SF
=
false
>
__global__
void
__launch_bounds__
(
512
,
VLLM_BLOCKS_PER_SM
(
512
))
cvt_fp16_to_fp4_sf_major
(
int32_t
numRows
,
int32_t
numCols
,
int32_t
sf_n_unpadded
,
Type
const
*
__restrict__
in
,
int32_t
sf_n_unpadded
,
int32_t
num_packed_cols
,
Type
const
*
__restrict__
in
,
float
const
*
__restrict__
SFScale
,
uint32_t
*
__restrict__
out
,
uint32_t
*
__restrict__
SFout
)
{
using
PackedVec
=
PackedVec
<
Type
>
;
using
PackedVec
=
PackedVec
<
Type
,
CVT_FP4_PACK16
>
;
static
constexpr
int
CVT_FP4_NUM_THREADS_PER_SF
=
(
CVT_FP4_SF_VEC_SIZE
/
CVT_FP4_ELTS_PER_THREAD
);
...
...
@@ -131,20 +132,20 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
// Iterate over all rows and cols including padded ones -
// ensures we visit every single scale factor address to initialize it.
for
(
int
rowIdx
=
blockIdx
.
x
;
rowIdx
<
numRows
;
rowIdx
+=
gridDim
.
x
)
{
if
(
colIdx
<
sf_n_unpadded
)
{
if
(
colIdx
<
num_packed_cols
)
{
PackedVec
in_vec
;
int64_t
inOffset
=
rowIdx
*
(
numCols
/
CVT_FP4_ELTS_PER_THREAD
)
+
colIdx
;
// If we are outside valid rows OR outside valid columns -> Use Zeros
bool
valid
=
(
rowIdx
<
numRows
)
&&
(
elem_idx
<
numCols
);
if
constexpr
(
CVT_FP4_PACK16
)
{
ld256_or_zero
_cg_u32
<
Type
>
(
in_vec
,
&
reinterpret_cast
<
const
uint32_t
*>
(
in
)[
inOffset
*
8
],
valid
);
ld256_
cg_
or_zero
(
reinterpret_cast
<
u32x8_t
&>
(
in_vec
),
&
reinterpret_cast
<
const
uint32_t
*>
(
in
)[
inOffset
*
8
],
valid
);
}
else
{
ld128_or_zero
_cg_u32
<
Type
>
(
in_vec
,
&
reinterpret_cast
<
const
uint32_t
*>
(
in
)[
inOffset
*
4
],
valid
);
ld128_
cg_
or_zero
(
reinterpret_cast
<
uint4
&>
(
in_vec
),
&
reinterpret_cast
<
const
uint32_t
*>
(
in
)[
inOffset
*
4
],
valid
);
}
auto
sf_out
=
...
...
@@ -222,7 +223,8 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
reinterpret_cast
<
uint32_t
*>
(
sf_out
));
});
}
else
{
int
grid_y
=
vllm
::
div_round_up
(
sf_n_unpadded
,
static_cast
<
int
>
(
block
.
x
));
int
num_packed_cols
=
n
/
CVT_FP4_ELTS_PER_THREAD
;
int
grid_y
=
vllm
::
div_round_up
(
num_packed_cols
,
static_cast
<
int
>
(
block
.
x
));
int
grid_x
=
std
::
min
(
m
,
std
::
max
(
1
,
(
multiProcessorCount
*
numBlocksPerSM
)
/
grid_y
));
dim3
grid
(
grid_x
,
grid_y
);
...
...
@@ -232,8 +234,8 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
auto
input_ptr
=
static_cast
<
cuda_type
const
*>
(
input
.
data_ptr
());
// NOTE: We don't support e8m0 scales at this moment.
vllm
::
cvt_fp16_to_fp4_sf_major
<
cuda_type
,
false
>
<<<
grid
,
block
,
0
,
stream
>>>
(
m
,
n
,
sf_n_unpadded
,
input_ptr
,
input_sf_ptr
,
<<<
grid
,
block
,
0
,
stream
>>>
(
m
,
n
,
sf_n_unpadded
,
num_packed_cols
,
input_ptr
,
input_sf_ptr
,
reinterpret_cast
<
uint32_t
*>
(
output_ptr
),
reinterpret_cast
<
uint32_t
*>
(
sf_out
));
});
...
...
csrc/quantization/fp4/nvfp4_utils.cuh
View file @
3fb4b5fa
...
...
@@ -18,9 +18,12 @@
#include <cuda_runtime.h>
#include <cuda_fp8.h>
#include <utility>
#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \
defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100)
#include "../../cuda_vec_utils.cuh"
#if defined(NVFP4_ENABLE_ELTS16) && defined(CUDA_VERSION) && \
CUDA_VERSION >= 12090
#define ELTS_PER_THREAD 16
constexpr
int
CVT_FP4_ELTS_PER_THREAD
=
16
;
constexpr
bool
CVT_FP4_PACK16
=
true
;
...
...
@@ -34,68 +37,6 @@ constexpr int CVT_FP4_SF_VEC_SIZE = 16;
namespace
vllm
{
// Convert PyTorch cpp type to CUDA type
template
<
typename
T
>
struct
CUDATypeConverter
{
using
Type
=
T
;
};
template
<
>
struct
CUDATypeConverter
<
at
::
Half
>
{
using
Type
=
half
;
};
template
<
>
struct
CUDATypeConverter
<
at
::
BFloat16
>
{
using
Type
=
__nv_bfloat16
;
};
// Get type2 from type or vice versa (applied to half and bfloat16)
template
<
typename
T
>
struct
TypeConverter
{
using
Type
=
half2
;
};
// keep for generality
template
<
>
struct
TypeConverter
<
half2
>
{
using
Type
=
half
;
};
template
<
>
struct
TypeConverter
<
half
>
{
using
Type
=
half2
;
};
template
<
>
struct
TypeConverter
<
__nv_bfloat162
>
{
using
Type
=
__nv_bfloat16
;
};
template
<
>
struct
TypeConverter
<
__nv_bfloat16
>
{
using
Type
=
__nv_bfloat162
;
};
#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \
defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100)
// Define a 32 bytes packed data type.
template
<
class
Type
>
struct
alignas
(
32
)
PackedVec
{
typename
TypeConverter
<
Type
>::
Type
elts
[
8
];
};
#else
// Define a 16 bytes packed data type.
template
<
class
Type
>
struct
alignas
(
16
)
PackedVec
{
typename
TypeConverter
<
Type
>::
Type
elts
[
4
];
};
#endif
template
<
>
struct
PackedVec
<
__nv_fp8_e4m3
>
{
__nv_fp8x2_e4m3
elts
[
8
];
};
template
<
typename
Int
>
__host__
__device__
inline
Int
round_up
(
Int
x
,
Int
y
)
{
static_assert
(
std
::
is_integral_v
<
Int
>
,
...
...
@@ -114,6 +55,18 @@ inline int computeEffectiveRows(int m) {
return
round_up
(
m
,
ROW_TILE
);
}
// Compute the shape of the swizzled SF output tensor.
// Returns (rounded_m, rounded_n / 4) where:
// rounded_m = round_up(m, 128)
// rounded_n = round_up(n / CVT_FP4_SF_VEC_SIZE, 4)
inline
std
::
pair
<
int64_t
,
int64_t
>
computeSwizzledSFShape
(
int64_t
m
,
int64_t
n
)
{
int64_t
rounded_m
=
round_up
(
m
,
static_cast
<
int64_t
>
(
128
));
int64_t
scale_n
=
n
/
CVT_FP4_SF_VEC_SIZE
;
int64_t
rounded_n
=
round_up
(
scale_n
,
static_cast
<
int64_t
>
(
4
));
return
{
rounded_m
,
rounded_n
/
4
};
}
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline
__device__
uint32_t
fp32_vec8_to_e2m1
(
float
(
&
array
)[
8
])
{
uint32_t
val
;
...
...
@@ -208,56 +161,6 @@ __device__ __forceinline__ float reciprocal_approximate_ftz(float a) {
return
b
;
}
template
<
class
Type
>
__device__
__forceinline__
void
ld128_or_zero_cg_u32
(
PackedVec
<
Type
>&
out
,
const
void
*
ptr
,
bool
pred
)
{
uint32_t
r0
,
r1
,
r2
,
r3
;
asm
volatile
(
"{
\n
"
" .reg .pred pr;
\n
"
" setp.ne.u32 pr, %4, 0;
\n
"
" mov.u32 %0, 0;
\n
"
" mov.u32 %1, 0;
\n
"
" mov.u32 %2, 0;
\n
"
" mov.u32 %3, 0;
\n
"
" @pr ld.global.cg.v4.u32 {%0,%1,%2,%3}, [%5];
\n
"
"}
\n
"
:
"=r"
(
r0
),
"=r"
(
r1
),
"=r"
(
r2
),
"=r"
(
r3
)
:
"r"
((
int
)
pred
),
"l"
(
ptr
));
*
reinterpret_cast
<
uint4
*>
(
&
out
)
=
uint4
{
r0
,
r1
,
r2
,
r3
};
}
template
<
class
Type
>
__device__
__forceinline__
void
ld256_or_zero_cg_u32
(
PackedVec
<
Type
>&
out
,
const
void
*
ptr
,
bool
pred
)
{
uint32_t
r0
,
r1
,
r2
,
r3
,
r4
,
r5
,
r6
,
r7
;
asm
volatile
(
"{
\n
"
" .reg .pred pr;
\n
"
" setp.ne.u32 pr, %8, 0;
\n
"
" mov.u32 %0, 0;
\n
"
" mov.u32 %1, 0;
\n
"
" mov.u32 %2, 0;
\n
"
" mov.u32 %3, 0;
\n
"
" mov.u32 %4, 0;
\n
"
" mov.u32 %5, 0;
\n
"
" mov.u32 %6, 0;
\n
"
" mov.u32 %7, 0;
\n
"
" @pr ld.global.cg.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%9];
\n
"
"}
\n
"
:
"=r"
(
r0
),
"=r"
(
r1
),
"=r"
(
r2
),
"=r"
(
r3
),
"=r"
(
r4
),
"=r"
(
r5
),
"=r"
(
r6
),
"=r"
(
r7
)
:
"r"
((
int
)
pred
),
"l"
(
ptr
));
reinterpret_cast
<
uint4
*>
(
&
out
)[
0
]
=
uint4
{
r0
,
r1
,
r2
,
r3
};
reinterpret_cast
<
uint4
*>
(
&
out
)[
1
]
=
uint4
{
r4
,
r5
,
r6
,
r7
};
}
// Compute SF output offset for swizzled tensor core layout.
// SF layout: [numMTiles, numKTiles, 32, 4, 4]
// Caller must precompute: numKTiles = (numCols + 63) / 64
...
...
@@ -315,8 +218,8 @@ __device__ __forceinline__ uint8_t* sf_out_rowmajor_u8(int row, int pack,
// Quantizes the provided PackedVec into the uint32_t output
template
<
class
Type
,
int
CVT_FP4_NUM_THREADS_PER_SF
,
bool
UE8M0_SF
=
false
>
__device__
__forceinline__
fp4_packed_t
cvt_warp_fp16_to_fp4
(
PackedVec
<
Type
>&
vec
,
float
SFScaleVal
,
uint8_t
*
SFout
)
{
__device__
__forceinline__
fp4_packed_t
cvt_warp_fp16_to_fp4
(
PackedVec
<
Type
,
CVT_FP4_PACK16
>&
vec
,
float
SFScaleVal
,
uint8_t
*
SFout
)
{
// Get absolute maximum values among the local 8 values.
auto
localMax
=
__habs2
(
vec
.
elts
[
0
]);
...
...
@@ -372,11 +275,7 @@ cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, uint8_t* SFout) {
#pragma unroll
for
(
int
i
=
0
;
i
<
CVT_FP4_ELTS_PER_THREAD
/
2
;
i
++
)
{
if
constexpr
(
std
::
is_same_v
<
Type
,
half
>
)
{
fp2Vals
[
i
]
=
__half22float2
(
vec
.
elts
[
i
]);
}
else
{
fp2Vals
[
i
]
=
__bfloat1622float2
(
vec
.
elts
[
i
]);
}
fp2Vals
[
i
]
=
cast_to_float2
(
vec
.
elts
[
i
]);
fp2Vals
[
i
].
x
*=
outputScale
;
fp2Vals
[
i
].
y
*=
outputScale
;
}
...
...
@@ -395,22 +294,19 @@ __device__ __forceinline__ float2 silu2(float2 x) {
}
template
<
class
Type
>
__inline__
__device__
PackedVec
<
Type
>
compute_silu_mul
(
const
PackedVec
<
Type
>&
x_vec
,
const
PackedVec
<
Type
>&
y_vec
)
{
PackedVec
<
Type
>
result
;
__inline__
__device__
PackedVec
<
Type
,
CVT_FP4_PACK16
>
compute_silu_mul
(
const
PackedVec
<
Type
,
CVT_FP4_PACK16
>&
x_vec
,
const
PackedVec
<
Type
,
CVT_FP4_PACK16
>&
y_vec
)
{
PackedVec
<
Type
,
CVT_FP4_PACK16
>
result
;
#pragma unroll
for
(
int
i
=
0
;
i
<
CVT_FP4_ELTS_PER_THREAD
/
2
;
++
i
)
{
// silu_mul in float32
if
constexpr
(
std
::
is_same_v
<
Type
,
half
>
)
{
float2
silu_vec
=
silu2
(
__half22float2
(
x_vec
.
elts
[
i
]));
result
.
elts
[
i
]
=
__float22half2_rn
(
__fmul2_rn
(
silu_vec
,
__half22float2
(
y_vec
.
elts
[
i
])));
}
else
{
float2
silu_vec
=
silu2
(
__bfloat1622float2
(
x_vec
.
elts
[
i
]));
result
.
elts
[
i
]
=
__float22bfloat162_rn
(
__fmul2_rn
(
silu_vec
,
__bfloat1622float2
(
y_vec
.
elts
[
i
])));
}
using
packed_t
=
typename
PackedTypeConverter
<
Type
>::
Type
;
float2
silu_vec
=
silu2
(
cast_to_float2
(
x_vec
.
elts
[
i
]));
float2
y_f2
=
cast_to_float2
(
y_vec
.
elts
[
i
]);
result
.
elts
[
i
]
=
cast_to_packed
<
packed_t
>
(
make_float2
(
silu_vec
.
x
*
y_f2
.
x
,
silu_vec
.
y
*
y_f2
.
y
));
}
return
result
;
}
...
...
csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
View file @
3fb4b5fa
...
...
@@ -29,31 +29,33 @@ __device__ void rms_norm_dynamic_per_token_quant_vec(
scalar_t
const
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
const
*
__restrict__
weight
,
// [hidden_size]
float
const
*
scale_ub
,
float
const
var_epsilon
,
int32_t
const
hidden_size
,
scalar_t
*
__restrict__
residual
=
nullptr
)
{
int32_t
const
input_stride
,
scalar_t
*
__restrict__
residual
=
nullptr
)
{
float
rms
=
0.0
f
;
float
token_scale
=
0.0
f
;
// Compute rms
vllm
::
vectorized
::
compute_rms
<
scalar_t
,
has_residual
>
(
&
rms
,
input
,
hidden_size
,
var_epsilon
,
residual
);
&
rms
,
input
,
hidden_size
,
input_stride
,
var_epsilon
,
residual
);
// Compute scale
vllm
::
vectorized
::
compute_dynamic_per_token_scales
<
scalar_t
,
scalar_out_t
,
has_residual
>
(
&
token_scale
,
scales
,
input
,
weight
,
rms
,
scale_ub
,
hidden_size
,
residual
);
input_stride
,
residual
);
// RMS Norm + Quant
if
constexpr
(
std
::
is_same_v
<
scalar_out_t
,
int8_t
>
)
{
token_scale
=
1.0
f
/
token_scale
;
vllm
::
vectorized
::
norm_and_quant
<
scalar_t
,
scalar_out_t
,
true
,
has_residual
>
(
out
,
input
,
weight
,
rms
,
&
token_scale
,
hidden_size
,
residual
);
has_residual
>
(
out
,
input
,
weight
,
rms
,
&
token_scale
,
hidden_size
,
input_stride
,
residual
);
}
else
{
// FP8 - Do not invert token_scale for exact match with FBGemm
vllm
::
vectorized
::
norm_and_quant
<
scalar_t
,
scalar_out_t
,
false
,
has_residual
>
(
out
,
input
,
weight
,
rms
,
&
token_scale
,
hidden_size
,
residual
);
has_residual
>
(
out
,
input
,
weight
,
rms
,
&
token_scale
,
hidden_size
,
input_stride
,
residual
);
}
}
...
...
@@ -65,38 +67,40 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
scalar_t
const
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
const
*
__restrict__
weight
,
// [hidden_size]
float
const
*
scale_ub
,
float
const
var_epsilon
,
int32_t
const
hidden_size
,
scalar_t
*
__restrict__
residual
=
nullptr
)
{
int32_t
const
input_stride
,
scalar_t
*
__restrict__
residual
=
nullptr
)
{
// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.
bool
const
can_vectorize
=
hidden_size
%
4
==
0
;
bool
const
can_vectorize
=
hidden_size
%
4
==
0
and
input_stride
%
4
==
0
;
if
(
can_vectorize
)
{
return
rms_norm_dynamic_per_token_quant_vec
<
scalar_t
,
scalar_out_t
,
has_residual
>
(
out
,
scales
,
input
,
weight
,
scale_ub
,
var_epsilon
,
hidden_size
,
residual
);
input_stride
,
residual
);
}
float
rms
=
0.0
f
;
float
token_scale
=
0.0
f
;
// Compute RMS
vllm
::
compute_rms
<
scalar_t
,
has_residual
>
(
&
rms
,
input
,
hidden_size
,
var_epsilon
,
residual
);
vllm
::
compute_rms
<
scalar_t
,
has_residual
>
(
&
rms
,
input
,
hidden_size
,
input_stride
,
var_epsilon
,
residual
);
// Compute Scale
vllm
::
compute_dynamic_per_token_scales
<
scalar_t
,
scalar_out_t
,
has_residual
>
(
&
token_scale
,
scales
,
input
,
weight
,
rms
,
scale_ub
,
hidden_size
,
residual
);
input_stride
,
residual
);
// RMS Norm + Quant
if
constexpr
(
std
::
is_same_v
<
scalar_out_t
,
int8_t
>
)
{
token_scale
=
1.0
f
/
token_scale
;
vllm
::
norm_and_quant
<
scalar_t
,
scalar_out_t
,
true
,
has_residual
>
(
out
,
input
,
weight
,
rms
,
&
token_scale
,
hidden_size
,
residual
);
out
,
input
,
weight
,
rms
,
&
token_scale
,
hidden_size
,
input_stride
,
residual
);
}
else
{
// FP8 - Do not invert s_token_scale for exact match with FBGemm
vllm
::
norm_and_quant
<
scalar_t
,
scalar_out_t
,
false
,
has_residual
>
(
out
,
input
,
weight
,
rms
,
&
token_scale
,
hidden_size
,
residual
);
out
,
input
,
weight
,
rms
,
&
token_scale
,
hidden_size
,
input_stride
,
residual
);
}
}
...
...
@@ -111,18 +115,20 @@ __global__ void rms_norm_per_block_quant_kernel(
scalar_t
const
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
const
*
__restrict__
weight
,
// [hidden_size]
float
const
*
scale_ub
,
float
const
var_epsilon
,
int32_t
const
hidden_size
,
scalar_t
*
__restrict__
residual
=
nullptr
)
{
int32_t
const
input_stride
,
scalar_t
*
__restrict__
residual
=
nullptr
,
int64_t
outer_scale_stride
=
1
)
{
float
rms
;
// Compute RMS
// Always able to vectorize due to constraints on hidden_size
vllm
::
vectorized
::
compute_rms
<
scalar_t
,
has_residual
>
(
&
rms
,
input
,
hidden_size
,
var_epsilon
,
residual
);
&
rms
,
input
,
hidden_size
,
input_stride
,
var_epsilon
,
residual
);
// Compute Scale
// Always able to vectorize due to constraints on hidden_size and group_size
vllm
::
vectorized
::
compute_dynamic_per_token_scales
<
scalar_t
,
scalar_out_t
,
has_residual
,
is_scale_transposed
,
group_size
>
(
nullptr
,
scales
,
input
,
weight
,
rms
,
scale_ub
,
hidden_size
,
residual
);
nullptr
,
scales
,
input
,
weight
,
rms
,
scale_ub
,
hidden_size
,
input_stride
,
residual
,
outer_scale_stride
);
// RMS Norm + Quant
// Always able to vectorize due to constraints on hidden_size
...
...
@@ -133,7 +139,8 @@ __global__ void rms_norm_per_block_quant_kernel(
vllm
::
vectorized
::
norm_and_quant
<
scalar_t
,
scalar_out_t
,
std
::
is_same_v
<
scalar_out_t
,
int8_t
>
,
has_residual
,
is_scale_transposed
,
group_size
>
(
out
,
input
,
weight
,
rms
,
scales
,
hidden_size
,
residual
);
out
,
input
,
weight
,
rms
,
scales
,
hidden_size
,
input_stride
,
residual
,
outer_scale_stride
);
}
}
// namespace vllm
...
...
@@ -149,6 +156,7 @@ void rms_norm_dynamic_per_token_quant_dispatch(
std
::
optional
<
at
::
Tensor
>
const
&
scale_ub
,
std
::
optional
<
at
::
Tensor
>&
residual
)
{
int32_t
hidden_size
=
input
.
size
(
-
1
);
int32_t
input_stride
=
input
.
view
({
-
1
,
hidden_size
}).
stride
(
0
);
auto
num_tokens
=
input
.
numel
()
/
hidden_size
;
dim3
grid
(
num_tokens
);
...
...
@@ -165,7 +173,7 @@ void rms_norm_dynamic_per_token_quant_dispatch(
out
.
data_ptr
<
scalar_t
>
(),
scales
.
data_ptr
<
float
>
(),
input
.
data_ptr
<
scalar_in_t
>
(),
weight
.
data_ptr
<
scalar_in_t
>
(),
scale_ub
.
has_value
()
?
scale_ub
->
data_ptr
<
float
>
()
:
nullptr
,
var_epsilon
,
hidden_size
,
var_epsilon
,
hidden_size
,
input_stride
,
has_residual
?
residual
->
data_ptr
<
scalar_in_t
>
()
:
nullptr
);
});
});
...
...
@@ -182,7 +190,9 @@ void rms_norm_dynamic_per_token_quant(
?
c10
::
ScalarType
::
Float8_e4m3fn
:
c10
::
ScalarType
::
Float8_e4m3fnuz
;
TORCH_CHECK
(
out
.
dtype
()
==
kFp8Type
||
out
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
out
.
is_contiguous
()
&&
input
.
is_contiguous
());
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
input
.
stride
(
-
1
)
==
1
,
"Input must be contiguous in the last dimension"
);
if
(
scale_ub
.
has_value
())
{
TORCH_CHECK
(
out
.
dtype
()
==
kFp8Type
);
...
...
@@ -191,6 +201,7 @@ void rms_norm_dynamic_per_token_quant(
TORCH_CHECK
(
scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
residual
)
{
TORCH_CHECK
(
residual
->
scalar_type
()
==
input
.
scalar_type
());
TORCH_CHECK
(
residual
->
is_contiguous
());
}
VLLM_DISPATCH_FLOATING_TYPES
(
...
...
@@ -212,6 +223,15 @@ void rms_norm_per_block_quant_dispatch(
std
::
optional
<
at
::
Tensor
>
const
&
scale_ub
,
std
::
optional
<
at
::
Tensor
>&
residual
,
bool
is_scale_transposed
)
{
int32_t
hidden_size
=
input
.
size
(
-
1
);
int32_t
input_stride
=
input
.
view
({
-
1
,
hidden_size
}).
stride
(
0
);
TORCH_CHECK
(
hidden_size
%
4
==
0
,
"Hidden size must be divisible by 4 for vectorized access"
);
TORCH_CHECK
(
input_stride
%
4
==
0
,
"Input stride must be divisible by 4 for vectorized access"
);
TORCH_CHECK
(
group_size
%
4
==
0
,
"Group size must be divisible by 4 for vectorized access"
);
auto
num_tokens
=
input
.
numel
()
/
hidden_size
;
dim3
grid
(
num_tokens
);
...
...
@@ -237,9 +257,10 @@ void rms_norm_per_block_quant_dispatch(
weight
.
data_ptr
<
scalar_in_t
>
(),
scale_ub
.
has_value
()
?
scale_ub
->
data_ptr
<
float
>
()
:
nullptr
,
var_epsilon
,
hidden_size
,
var_epsilon
,
hidden_size
,
input_stride
,
has_residual
?
residual
->
data_ptr
<
scalar_in_t
>
()
:
nullptr
);
:
nullptr
,
scales
.
stride
(
1
));
});
});
});
...
...
@@ -257,7 +278,9 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
?
c10
::
ScalarType
::
Float8_e4m3fn
:
c10
::
ScalarType
::
Float8_e4m3fnuz
;
TORCH_CHECK
(
out
.
dtype
()
==
kFp8Type
||
out
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
out
.
is_contiguous
()
&&
input
.
is_contiguous
());
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
input
.
stride
(
-
1
)
==
1
,
"Input must be contiguous in the last dimension"
);
if
(
scale_ub
.
has_value
())
{
TORCH_CHECK
(
out
.
dtype
()
==
kFp8Type
);
...
...
@@ -266,11 +289,17 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
TORCH_CHECK
(
scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
residual
)
{
TORCH_CHECK
(
residual
->
scalar_type
()
==
input
.
scalar_type
());
TORCH_CHECK
(
residual
->
is_contiguous
());
}
TORCH_CHECK
(
group_size
==
128
||
group_size
==
64
,
"Unsupported group size: "
,
group_size
);
if
(
scales
.
stride
(
1
)
>
1
)
{
TORCH_CHECK
(
is_scale_transposed
,
"Outer scale stride must be 1 when scales are not transposed"
);
}
rms_norm_per_block_quant_dispatch
(
out
,
input
,
weight
,
scales
,
group_size
,
var_epsilon
,
scale_ub
,
residual
,
is_scale_transposed
);
...
...
csrc/quantization/fused_kernels/layernorm_utils.cuh
View file @
3fb4b5fa
...
...
@@ -16,14 +16,17 @@ namespace vllm {
// has_residual must be true, if residual is not a nullptr
template
<
typename
scalar_t
,
bool
has_residual
=
false
>
__device__
void
compute_rms
(
float
*
rms
,
scalar_t
const
*
__restrict__
input
,
int32_t
const
hidden_size
,
float
const
epsilon
,
int32_t
const
hidden_size
,
int32_t
const
input_stride
,
float
const
epsilon
,
scalar_t
const
*
__restrict__
residual
=
nullptr
)
{
int64_t
const
input_token_offset
=
blockIdx
.
x
*
static_cast
<
int64_t
>
(
input_stride
);
int64_t
const
token_offset
=
blockIdx
.
x
*
static_cast
<
int64_t
>
(
hidden_size
);
// sum of squares
float
ss
=
0.0
f
;
for
(
auto
i
=
threadIdx
.
x
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
float
x
=
static_cast
<
float
>
(
input
[
token_offset
+
i
]);
float
x
=
static_cast
<
float
>
(
input
[
input_
token_offset
+
i
]);
if
constexpr
(
has_residual
)
{
x
+=
static_cast
<
float
>
(
residual
[
token_offset
+
i
]);
}
...
...
@@ -73,15 +76,20 @@ __device__ void compute_dynamic_per_token_scales(
float
*
__restrict__
token_scale
,
float
*
__restrict__
all_token_scales
,
scalar_t
const
*
__restrict__
input
,
scalar_t
const
*
__restrict__
weight
,
float
const
rms
,
float
const
*
__restrict__
scale_ub
,
int32_t
const
hidden_size
,
scalar_t
const
*
__restrict__
residual
=
nullptr
,
int32_t
const
group_size
=
0
)
{
int32_t
const
hidden_size
,
int32_t
const
input_stride
,
scalar_t
const
*
__restrict__
residual
=
nullptr
,
int32_t
const
group_size
=
0
,
int64_t
outer_scale_stride
=
1
)
{
float
block_absmax_val_maybe
=
0.0
f
;
constexpr
scalar_out_t
qmax
{
quant_type_max_v
<
scalar_out_t
>
};
__syncthreads
();
int64_t
const
input_token_offset
=
blockIdx
.
x
*
static_cast
<
int64_t
>
(
input_stride
);
int64_t
const
token_offset
=
blockIdx
.
x
*
static_cast
<
int64_t
>
(
hidden_size
);
if
(
group_size
>
0
)
{
__shared__
float
s_max_vals
[
1024
];
int64_t
const
token_offset
=
blockIdx
.
x
*
static_cast
<
int64_t
>
(
hidden_size
);
int64_t
num_groups
=
hidden_size
/
group_size
;
__shared__
float
s_max_vals
[
1024
];
int64_t
const
threads_per_group
=
blockDim
.
x
/
num_groups
;
int64_t
const
thread_in_group
=
threadIdx
.
x
%
threads_per_group
;
int64_t
const
group_offset
=
threadIdx
.
x
/
threads_per_group
*
group_size
;
...
...
@@ -89,7 +97,7 @@ __device__ void compute_dynamic_per_token_scales(
int64_t
const
thread_end
=
min
(
group_offset
+
group_size
,
static_cast
<
int64_t
>
(
hidden_size
));
for
(
auto
i
=
thread_offset
;
i
<
thread_end
;
i
+=
threads_per_group
)
{
float
x
=
static_cast
<
float
>
(
input
[
token_offset
+
i
]);
float
x
=
static_cast
<
float
>
(
input
[
input_
token_offset
+
i
]);
if
constexpr
(
has_residual
)
{
x
+=
static_cast
<
float
>
(
residual
[
token_offset
+
i
]);
}
...
...
@@ -133,7 +141,9 @@ __device__ void compute_dynamic_per_token_scales(
scale
=
max
(
scale
/
qmax
,
min_scaling_factor
<
scalar_out_t
>::
val
());
// Global output store
if
constexpr
(
is_scale_transposed
)
{
all_token_scales
[(
threadIdx
.
x
/
threads_per_group
)
*
gridDim
.
x
+
int64_t
const
scale_rows
=
(
gridDim
.
x
+
outer_scale_stride
-
1
)
/
outer_scale_stride
*
outer_scale_stride
;
all_token_scales
[(
threadIdx
.
x
/
threads_per_group
)
*
scale_rows
+
blockIdx
.
x
]
=
scale
;
}
else
{
all_token_scales
[
blockIdx
.
x
*
num_groups
+
...
...
@@ -142,10 +152,8 @@ __device__ void compute_dynamic_per_token_scales(
}
__syncthreads
();
}
else
{
int64_t
const
token_offset
=
blockIdx
.
x
*
static_cast
<
int64_t
>
(
hidden_size
);
for
(
auto
i
=
threadIdx
.
x
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
float
x
=
static_cast
<
float
>
(
input
[
token_offset
+
i
]);
float
x
=
static_cast
<
float
>
(
input
[
input_
token_offset
+
i
]);
if
constexpr
(
has_residual
)
{
x
+=
static_cast
<
float
>
(
residual
[
token_offset
+
i
]);
}
...
...
@@ -180,17 +188,18 @@ __device__ void compute_dynamic_per_token_scales(
template
<
typename
scalar_t
,
typename
scalar_out_t
,
bool
is_scale_inverted
,
bool
has_residual
=
false
,
bool
is_scale_transposed
=
false
>
__device__
void
norm_and_quant
(
scalar_out_t
*
__restrict__
output
,
scalar_t
const
*
__restrict__
input
,
scalar_t
const
*
__restrict__
weight
,
float
const
rms
,
float
*
const
scale
,
int32_t
const
hidden_size
,
scalar_t
*
__restrict__
residual
=
nullptr
,
int32_t
const
group_size
=
0
)
{
__device__
void
norm_and_quant
(
scalar_out_t
*
__restrict__
output
,
scalar_t
const
*
__restrict__
input
,
scalar_t
const
*
__restrict__
weight
,
float
const
rms
,
float
*
const
scale
,
int32_t
const
hidden_size
,
int32_t
const
input_stride
,
scalar_t
*
__restrict__
residual
=
nullptr
,
int32_t
const
group_size
=
0
,
int64_t
outer_scale_stride
=
1
)
{
int64_t
const
input_token_offset
=
blockIdx
.
x
*
static_cast
<
int64_t
>
(
input_stride
);
int64_t
const
token_offset
=
blockIdx
.
x
*
static_cast
<
int64_t
>
(
hidden_size
);
for
(
auto
i
=
threadIdx
.
x
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
float
x
=
static_cast
<
float
>
(
input
[
token_offset
+
i
]);
float
x
=
static_cast
<
float
>
(
input
[
input_
token_offset
+
i
]);
if
constexpr
(
has_residual
)
{
x
+=
static_cast
<
float
>
(
residual
[
token_offset
+
i
]);
residual
[
token_offset
+
i
]
=
static_cast
<
scalar_t
>
(
x
);
...
...
@@ -202,7 +211,9 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
int64_t
scale_idx
=
0
;
if
(
group_size
>
0
)
{
if
constexpr
(
is_scale_transposed
)
{
scale_idx
=
(
i
/
group_size
)
*
gridDim
.
x
+
blockIdx
.
x
;
int64_t
const
scale_rows
=
(
gridDim
.
x
+
outer_scale_stride
-
1
)
/
outer_scale_stride
*
outer_scale_stride
;
scale_idx
=
(
i
/
group_size
)
*
scale_rows
+
blockIdx
.
x
;
}
else
{
scale_idx
=
blockIdx
.
x
*
(
hidden_size
/
group_size
)
+
i
/
group_size
;
}
...
...
@@ -222,13 +233,16 @@ namespace vectorized {
// hidden_size must be a multiple of 4
template
<
typename
scalar_t
,
bool
has_residual
=
false
>
__device__
void
compute_rms
(
float
*
rms
,
scalar_t
const
*
__restrict__
input
,
int32_t
const
hidden_size
,
float
const
epsilon
,
int32_t
const
hidden_size
,
int32_t
const
input_stride
,
float
const
epsilon
,
scalar_t
const
*
__restrict__
residual
=
nullptr
)
{
int64_t
const
input_token_offset
=
blockIdx
.
x
*
static_cast
<
int64_t
>
(
input_stride
);
int64_t
const
token_offset
=
blockIdx
.
x
*
static_cast
<
int64_t
>
(
hidden_size
);
// Vectorized input/output to better utilize memory bandwidth.
vec4_t
<
scalar_t
>
const
*
vec_input
=
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
&
input
[
token_offset
]);
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
&
input
[
input_
token_offset
]);
vec4_t
<
scalar_t
>
const
*
vec_residual
=
nullptr
;
if
constexpr
(
has_residual
)
{
vec_residual
=
...
...
@@ -286,8 +300,9 @@ __device__ void compute_dynamic_per_token_scales(
float
*
__restrict__
token_scale
,
float
*
__restrict__
all_token_scales
,
scalar_t
const
*
__restrict__
input
,
scalar_t
const
*
__restrict__
weight
,
float
const
rms
,
float
const
*
__restrict__
scale_ub
,
int32_t
const
hidden_size
,
scalar_t
const
*
__restrict__
residual
=
nullptr
)
{
int32_t
const
hidden_size
,
int32_t
const
input_stride
,
scalar_t
const
*
__restrict__
residual
=
nullptr
,
int64_t
outer_scale_stride
=
1
)
{
constexpr
scalar_out_t
qmax
{
quant_type_max_v
<
scalar_out_t
>
};
const
int
VEC_SIZE
=
4
;
...
...
@@ -298,10 +313,13 @@ __device__ void compute_dynamic_per_token_scales(
vec4_t
<
scalar_t
>
const
*
vec_weight
=
nullptr
;
vec4_t
<
scalar_t
>
const
*
vec_residual
=
nullptr
;
int64_t
const
input_token_offset
=
blockIdx
.
x
*
static_cast
<
int64_t
>
(
input_stride
);
int64_t
const
token_offset
=
blockIdx
.
x
*
static_cast
<
int64_t
>
(
hidden_size
);
if
constexpr
(
group_size
>
0
)
{
__shared__
float
s_max_vals
[
1024
];
int64_t
const
token_offset
=
blockIdx
.
x
*
static_cast
<
int64_t
>
(
hidden_size
);
int64_t
const
num_groups
=
hidden_size
/
group_size
;
int64_t
const
threads_per_group
=
blockDim
.
x
/
num_groups
;
int64_t
const
thread_in_group
=
threadIdx
.
x
%
threads_per_group
;
...
...
@@ -310,7 +328,8 @@ __device__ void compute_dynamic_per_token_scales(
int64_t
const
thread_offset
=
group_offset
+
thread_in_group
;
int64_t
const
thread_end
=
min
(
group_offset
+
(
group_size
>>
2
),
static_cast
<
int64_t
>
(
hidden_size
>>
2
));
vec_input
=
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
&
input
[
token_offset
]);
vec_input
=
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
&
input
[
input_token_offset
]);
vec_weight
=
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
weight
);
if
constexpr
(
has_residual
)
{
vec_residual
=
...
...
@@ -382,7 +401,9 @@ __device__ void compute_dynamic_per_token_scales(
scale
=
max
(
scale
/
qmax
,
min_scaling_factor
<
scalar_out_t
>::
val
());
// Global output store
if
constexpr
(
is_scale_transposed
)
{
all_token_scales
[(
threadIdx
.
x
/
threads_per_group
)
*
gridDim
.
x
+
int64_t
const
scale_rows
=
(
gridDim
.
x
+
outer_scale_stride
-
1
)
/
outer_scale_stride
*
outer_scale_stride
;
all_token_scales
[(
threadIdx
.
x
/
threads_per_group
)
*
scale_rows
+
blockIdx
.
x
]
=
scale
;
}
else
{
all_token_scales
[
blockIdx
.
x
*
num_groups
+
...
...
@@ -392,8 +413,8 @@ __device__ void compute_dynamic_per_token_scales(
__syncthreads
();
}
else
{
int64_t
const
token_offset
=
blockIdx
.
x
*
static_cast
<
int64_t
>
(
hidden_size
);
vec_input
=
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
&
input
[
token_offset
]);
vec_input
=
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
&
input
[
input_
token_offset
]);
vec_weight
=
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
weight
);
if
constexpr
(
has_residual
)
{
vec_residual
=
...
...
@@ -458,17 +479,18 @@ __device__ void compute_dynamic_per_token_scales(
template
<
typename
scalar_t
,
typename
scalar_out_t
,
bool
is_scale_inverted
,
bool
has_residual
=
false
,
bool
is_scale_transposed
=
false
,
int32_t
group_size
=
0
>
__device__
void
norm_and_quant
(
scalar_out_t
*
__restrict__
output
,
scalar_t
const
*
__restrict__
input
,
scalar_t
const
*
__restrict__
weight
,
float
const
rms
,
float
*
const
scale
,
int32_t
const
hidden_size
,
scalar_t
*
__restrict__
residual
=
nullptr
)
{
__device__
void
norm_and_quant
(
scalar_out_t
*
__restrict__
output
,
scalar_t
const
*
__restrict__
input
,
scalar_t
const
*
__restrict__
weight
,
float
const
rms
,
float
*
const
scale
,
int32_t
const
hidden_size
,
int32_t
const
input_stride
,
scalar_t
*
__restrict__
residual
=
nullptr
,
int64_t
outer_scale_stride
=
1
)
{
int64_t
const
input_token_offset
=
blockIdx
.
x
*
static_cast
<
int64_t
>
(
input_stride
);
int64_t
const
token_offset
=
blockIdx
.
x
*
static_cast
<
int64_t
>
(
hidden_size
);
// Vectorized input/output/weight/residual to better utilize memory bandwidth.
vec4_t
<
scalar_t
>
const
*
vec_input
=
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
&
input
[
token_offset
]);
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
&
input
[
input_
token_offset
]);
vec4_t
<
scalar_t
>
const
*
vec_weight
=
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
weight
);
q8x4_t
<
scalar_out_t
>*
vec_output
=
...
...
@@ -516,7 +538,9 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
int64_t
const
num_groups
=
hidden_size
/
group_size
;
int64_t
scale_idx
=
0
;
if
constexpr
(
is_scale_transposed
)
{
scale_idx
=
(
i
*
VEC_SIZE
/
group_size
)
*
gridDim
.
x
+
blockIdx
.
x
;
int64_t
const
scale_rows
=
(
gridDim
.
x
+
outer_scale_stride
-
1
)
/
outer_scale_stride
*
outer_scale_stride
;
scale_idx
=
(
i
*
VEC_SIZE
/
group_size
)
*
scale_rows
+
blockIdx
.
x
;
}
else
{
scale_idx
=
blockIdx
.
x
*
num_groups
+
i
*
VEC_SIZE
/
group_size
;
}
...
...
csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh
View file @
3fb4b5fa
...
...
@@ -12,6 +12,68 @@ namespace vllm {
using
c3x
::
cutlass_gemm_caller
;
// Custom wrapper to allow specifying EpilogueTile for small M
template
<
typename
ElementAB_
,
typename
ElementD_
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue_
,
typename
TileShape
,
typename
ClusterShape
,
typename
KernelSchedule
,
typename
EpilogueSchedule
,
typename
EpilogueTile
>
struct
cutlass_3x_gemm_sm120_custom
{
using
ElementAB
=
ElementAB_
;
using
LayoutA
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
AlignmentA
=
128
/
cutlass
::
sizeof_bits
<
ElementAB
>::
value
;
using
LayoutB
=
cutlass
::
layout
::
ColumnMajor
;
static
constexpr
int
AlignmentB
=
128
/
cutlass
::
sizeof_bits
<
ElementAB
>::
value
;
using
ElementC
=
void
;
using
LayoutC
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
AlignmentC
=
128
/
cutlass
::
sizeof_bits
<
ElementD_
>::
value
;
using
ElementD
=
ElementD_
;
using
LayoutD
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
AlignmentD
=
AlignmentC
;
using
ElementAcc
=
typename
std
::
conditional
<
std
::
is_same_v
<
ElementAB
,
int8_t
>
,
int32_t
,
float
>::
type
;
using
Epilogue
=
Epilogue_
<
ElementAcc
,
ElementD
,
TileShape
>
;
// MMA type
using
ElementAccumulator
=
float
;
// Epilogue types
using
ElementBias
=
cutlass
::
half_t
;
using
ElementCompute
=
float
;
using
ElementAux
=
ElementD
;
using
LayoutAux
=
LayoutD
;
using
ElementAmax
=
float
;
using
EVTCompute
=
typename
Epilogue
::
EVTCompute
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
cutlass
::
arch
::
Sm120
,
cutlass
::
arch
::
OpClassTensorOp
,
TileShape
,
ClusterShape
,
EpilogueTile
,
// Use custom EpilogueTile
ElementAccumulator
,
ElementCompute
,
ElementC
,
LayoutC
,
AlignmentC
,
ElementD
,
LayoutD
,
AlignmentD
,
EpilogueSchedule
,
EVTCompute
>::
CollectiveOp
;
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
cutlass
::
arch
::
Sm120
,
cutlass
::
arch
::
OpClassTensorOp
,
ElementAB
,
LayoutA
,
AlignmentA
,
ElementAB
,
LayoutB
,
AlignmentB
,
ElementAccumulator
,
TileShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
KernelSchedule
,
void
>::
CollectiveOp
;
using
GemmKernel
=
enable_sm120_only
<
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
Shape
<
int
,
int
,
int
,
int
>
,
CollectiveMainloop
,
CollectiveEpilogue
,
void
>>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm120_fp8_config_default
{
...
...
@@ -25,6 +87,54 @@ struct sm120_fp8_config_default {
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm120_fp8_config_M64
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
// SM120 Cooperative kernel requires Tile M >= 128.
// For M=64 tile, we use Pingpong schedule which is more flexible with small
// tiles.
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpong
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
using
TileShape
=
Shape
<
_64
,
_64
,
_128
>
;
// CUTLASS 3.x on SM120 currently restricts programmatic multicast (Cluster >
// 1) for certain schedules/types. Reverting to 1x1x1 to ensure compilation.
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm_sm120
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm120_fp8_config_M32
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpong
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
using
TileShape
=
Shape
<
_32
,
_64
,
_128
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
// Use custom gemm to specify EpilogueTile M=32
using
Cutlass3xGemm
=
cutlass_3x_gemm_sm120_custom
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
,
Shape
<
_32
,
_32
>>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm120_fp8_config_M16
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpong
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
using
TileShape
=
Shape
<
_16
,
_64
,
_128
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
// Use custom gemm to specify EpilogueTile M=16
using
Cutlass3xGemm
=
cutlass_3x_gemm_sm120_custom
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
,
Shape
<
_16
,
_32
>>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
...
...
@@ -36,6 +146,28 @@ inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out,
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
int
M
=
a
.
size
(
0
);
if
(
M
<=
16
)
{
using
Cutlass3xGemmM16
=
typename
sm120_fp8_config_M16
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
return
cutlass_gemm_caller
<
Cutlass3xGemmM16
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
if
(
M
<=
32
)
{
using
Cutlass3xGemmM32
=
typename
sm120_fp8_config_M32
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
return
cutlass_gemm_caller
<
Cutlass3xGemmM32
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
if
(
M
<=
256
)
{
using
Cutlass3xGemmM64
=
typename
sm120_fp8_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
return
cutlass_gemm_caller
<
Cutlass3xGemmM64
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
using
Cutlass3xGemmDefault
=
typename
sm120_fp8_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
...
...
@@ -64,4 +196,4 @@ void cutlass_scaled_mm_sm120_fp8_epilogue(torch::Tensor& out,
}
}
}
// namespace vllm
\ No newline at end of file
}
// namespace vllm
Prev
1
…
6
7
8
9
10
11
12
13
14
…
25
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment