Unverified Commit c0c98b8b authored by Maral's avatar Maral Committed by GitHub
Browse files

[Bugfix] Add Marlin kernel in block scaled mm kernel selection. (#40105)


Signed-off-by: default avatarmaral <maralbahari.98@gmail.com>
parent 8d2cff81
......@@ -186,12 +186,13 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
# in priority/performance order (when available)
_POSSIBLE_FP8_BLOCK_KERNELS: dict[
PlatformEnum, list[type[Fp8BlockScaledMMLinearKernel]]
PlatformEnum, list[type[Fp8BlockScaledMMLinearKernel | FP8ScaledMMLinearKernel]]
] = {
PlatformEnum.CUDA: [
FlashInferFp8DeepGEMMDynamicBlockScaledKernel,
DeepGemmFp8BlockScaledMMKernel,
CutlassFp8BlockScaledMMKernel,
MarlinFP8ScaledMMLinearKernel,
TritonFp8BlockScaledMMKernel,
],
PlatformEnum.ROCM: [
......@@ -392,6 +393,19 @@ def init_fp8_linear_kernel(
scope="global",
)
# TODO make scaled_mm kernels inherit from MMLinearKernel
# only MarlinFP8ScaledMMLinearKernel is a type of FP8ScaledMMLinearKernel.
if issubclass(kernel_type, FP8ScaledMMLinearKernel):
return kernel_type(
scaled_mm_linear_kernel_config,
layer_param_names=[
"weight",
"weight_scale",
"input_scale",
"input_scale_ub",
],
)
return kernel_type(
scaled_mm_linear_kernel_config,
)
......@@ -399,7 +413,7 @@ def init_fp8_linear_kernel(
else:
kernel_type = choose_scaled_mm_linear_kernel(
config=scaled_mm_linear_kernel_config,
possible_kernels=_POSSIBLE_FP8_KERNELS, # type: ignore[misc]
possible_kernels=_POSSIBLE_FP8_KERNELS, # type: ignore[arg-type]
force_kernel=force_kernel,
)
if module_name:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment