During the build process above, the CUTLASS source code is downloaded automatically; if you have already downloaded it, you can point the build at a local CUTLASS path instead. The kernel's operand types, layouts, and alignments are configured as follows:
```cpp
// A matrix configuration
using ElementA   = cutlass::nv_float4_t<cutlass::float_e2m1_t>;  // Element type for A matrix operand
using LayoutATag = cutlass::layout::RowMajor;                    // Layout type for A matrix operand
static constexpr int AlignmentA = 32;                            // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)

// B matrix configuration
using ElementB   = cutlass::nv_float4_t<cutlass::float_e2m1_t>;  // Element type for B matrix operand
using LayoutBTag = cutlass::layout::ColumnMajor;                 // Layout type for B matrix operand
static constexpr int AlignmentB = 32;                            // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

// C/D matrix configuration
using ElementD   = cutlass::bfloat16_t;                          // Element type for D matrix operand
using ElementC   = cutlass::bfloat16_t;                          // Element type for C matrix operand
using LayoutCTag = cutlass::layout::RowMajor;                    // Layout type for C matrix operand
using LayoutDTag = cutlass::layout::RowMajor;                    // Layout type for D matrix operand
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;  // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;  // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)

// Kernel functional config
using ElementAccumulator = float;                                // Element type for internal accumulation
using ArchTag       = cutlass::arch::Sm120;                      // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag

using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;  // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;  // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
```
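The `Gemm` type referenced by `LayoutSFA`/`LayoutSFB` is assembled from these aliases via the CUTLASS 3.x collective builders. Below is a minimal sketch following the structure of CUTLASS's block-scaled GEMM examples; `MmaTileShape` and `ClusterShape` are illustrative assumptions, not values taken from this document:

```cpp
// Sketch only: assembling the Gemm type from the aliases above, following the
// pattern of CUTLASS's SM120 block-scaled GEMM examples. Tile/cluster shapes
// are illustrative choices.
using MmaTileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag, OperatorClass,
    MmaTileShape, ClusterShape,
    cutlass::epilogue::collective::EpilogueTileAuto,
    ElementAccumulator, ElementAccumulator,
    ElementC, LayoutCTag, AlignmentC,
    ElementD, LayoutDTag, AlignmentD,
    cutlass::epilogue::collective::EpilogueScheduleAuto
  >::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OperatorClass,
    ElementA, LayoutATag, AlignmentA,
    ElementB, LayoutBTag, AlignmentB,
    ElementAccumulator,
    MmaTileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAutoCarveout<
        static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
    cutlass::gemm::collective::KernelScheduleAuto
  >::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    cute::Shape<int, int, int, int>,  // problem shape (M, N, K, L)
    CollectiveMainloop,
    CollectiveEpilogue>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
```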
For a floating-point format with a sign bit `s`, `e` exponent bits, and mantissa bits `d1, d2, d3`, a normal value decodes as:

`ans = (-1)^s * 2^(p-b) * (1 + d1/2 + d2/4 + d3/8)`

where `b = 2^(e-1) - 1` is the exponent bias, `p` is the value of the exponent bits, and `d1, d2, d3` are the values of the mantissa bits.

For fp4 the format is E2M1, so the formula above simplifies to:

`b = 2^(e-1) - 1 = 2^(2-1) - 1 = 1`

`ans = (-1)^s * 2^(p-1) * (1 + d1/2)`

Example: `0101`

`s=0, p=(10)_2=2, d1=1`

`ans = (-1)^0 * 2^(2-1) * (1 + 1/2) = 3`
In a standard fp format, some encodings are reserved for inf and nan, which would cap E2M1 at a maximum of ±3. nvfp4 is specialized to drop inf and nan, raising the maximum representable value to ±6.
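To make the decoding rule concrete, here is a minimal sketch of an E2M1 decoder. The function name `decode_e2m1` and the bit layout (sign in bit 3, exponent in bits 2 and 1, mantissa in bit 0) are assumptions for illustration, and the `p = 0` subnormal case follows the usual no-implicit-one convention, which the formula above omits:

```cpp
#include <cstdint>
#include <cstdio>

// Decode a 4-bit E2M1 code: bit 3 = sign s, bits 2..1 = exponent p, bit 0 = mantissa d1.
float decode_e2m1(uint8_t code) {
    int s  = (code >> 3) & 1;
    int p  = (code >> 1) & 3;  // exponent field value
    int d1 = code & 1;         // single mantissa bit
    float sign = s ? -1.0f : 1.0f;
    if (p == 0) {
        return sign * d1 * 0.5f;  // subnormal: no implicit leading 1
    }
    return sign * (1.0f + d1 * 0.5f) * float(1 << (p - 1));  // 2^(p-b) with bias b = 1
}

int main() {
    printf("%g\n", decode_e2m1(0b0101));  // prints 3, matching the worked example
    printf("%g\n", decode_e2m1(0b0111));  // prints 6, the nvfp4 maximum
}
```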
**Both weights and activations use per-group quantization with a group size of 16, and the quantization scales are stored in fp8 (e4m3) format.**

Because the quantization scales must themselves be stored in fp8, they also need to be rescaled, so the fp4 quantization process differs somewhat from the common w8a8-int8 flow.
The quantization process is as follows:
Given a set of numbers, denoted as `X`:
#### Calculate scale
`scale1 = max(abs(Xg)) / 6.0`
where `Xg` denotes one group of 16 numbers and 6.0 is the maximum representable value of nvfp4.
#### Quantize scale
`global_scale = 6.0 * 448.0 / max(abs(X))`
`scale2 = global_scale * scale1`
where 448.0 is the maximum representable value of fp8 (e4m3). Expanding:
`scale2 = 6.0 * 448.0 / max(abs(X)) * max(abs(Xg)) / 6.0`
`scale2 = max(abs(Xg)) / max(abs(X)) * 448.0`
At this point `scale2` has been rescaled into the representable range of fp8 (e4m3); `scale2` is then quantized to fp8:
`scale2_fp8 = quant_fp8(scale2)`
`scale2_fp8` serves as the final quantization scale parameter required for matrix multiplication
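Putting the steps together, here is a minimal host-side sketch of the scale computation under the scheme above. The helper name `compute_scales`, the flat `std::vector<float>` input, and the constant names are illustrative assumptions; a production implementation would operate on device tensors and use the kernel's interleaved scale-factor layout:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>
#include "cutlass/numeric_types.h"  // cutlass::float_e4m3_t

constexpr int   GROUP    = 16;      // per-group quantization group size
constexpr float FP4_MAX  = 6.0f;    // max magnitude of nvfp4 (E2M1 without inf/nan)
constexpr float E4M3_MAX = 448.0f;  // max magnitude of fp8 (e4m3)

// Compute the per-group fp8 scales for a tensor X (assumed non-empty and nonzero).
std::vector<cutlass::float_e4m3_t> compute_scales(std::vector<float> const& X) {
    // global_scale = 6.0 * 448.0 / max(abs(X))
    float amax = 0.0f;
    for (float v : X) amax = std::max(amax, std::fabs(v));
    float global_scale = FP4_MAX * E4M3_MAX / amax;

    std::vector<cutlass::float_e4m3_t> scales;
    for (std::size_t g = 0; g < X.size(); g += GROUP) {
        // scale1 = max(abs(Xg)) / 6.0
        float gmax = 0.0f;
        for (std::size_t i = g; i < std::min(g + GROUP, X.size()); ++i)
            gmax = std::max(gmax, std::fabs(X[i]));
        float scale1 = gmax / FP4_MAX;

        // scale2 = global_scale * scale1 == max(abs(Xg)) / max(abs(X)) * 448.0
        float scale2 = global_scale * scale1;

        // scale2_fp8 = quant_fp8(scale2): the e4m3 conversion is the quantization step
        scales.push_back(cutlass::float_e4m3_t(scale2));
    }
    return scales;
}
```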