"vscode:/vscode.git/clone" did not exist on "fe743b798dfa56aea3e2cb7182365ba3495489ee"
Unverified Commit c0c98b8b authored by Maral, committed by GitHub
Browse files

[Bugfix] Add Marlin kernel in block scaled mm kernel selection. (#40105)


Signed-off-by: maral <maralbahari.98@gmail.com>
parent 8d2cff81
...@@ -186,12 +186,13 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = ...@@ -186,12 +186,13 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
# in priority/performance order (when available) # in priority/performance order (when available)
_POSSIBLE_FP8_BLOCK_KERNELS: dict[ _POSSIBLE_FP8_BLOCK_KERNELS: dict[
PlatformEnum, list[type[Fp8BlockScaledMMLinearKernel]] PlatformEnum, list[type[Fp8BlockScaledMMLinearKernel | FP8ScaledMMLinearKernel]]
] = { ] = {
PlatformEnum.CUDA: [ PlatformEnum.CUDA: [
FlashInferFp8DeepGEMMDynamicBlockScaledKernel, FlashInferFp8DeepGEMMDynamicBlockScaledKernel,
DeepGemmFp8BlockScaledMMKernel, DeepGemmFp8BlockScaledMMKernel,
CutlassFp8BlockScaledMMKernel, CutlassFp8BlockScaledMMKernel,
MarlinFP8ScaledMMLinearKernel,
TritonFp8BlockScaledMMKernel, TritonFp8BlockScaledMMKernel,
], ],
PlatformEnum.ROCM: [ PlatformEnum.ROCM: [
...@@ -392,6 +393,19 @@ def init_fp8_linear_kernel( ...@@ -392,6 +393,19 @@ def init_fp8_linear_kernel(
scope="global", scope="global",
) )
# TODO make scaled_mm kernels inherit from MMLinearKernel
# only MarlinFP8ScaledMMLinearKernel is a type of FP8ScaledMMLinearKernel.
if issubclass(kernel_type, FP8ScaledMMLinearKernel):
return kernel_type(
scaled_mm_linear_kernel_config,
layer_param_names=[
"weight",
"weight_scale",
"input_scale",
"input_scale_ub",
],
)
return kernel_type( return kernel_type(
scaled_mm_linear_kernel_config, scaled_mm_linear_kernel_config,
) )
...@@ -399,7 +413,7 @@ def init_fp8_linear_kernel( ...@@ -399,7 +413,7 @@ def init_fp8_linear_kernel(
else: else:
kernel_type = choose_scaled_mm_linear_kernel( kernel_type = choose_scaled_mm_linear_kernel(
config=scaled_mm_linear_kernel_config, config=scaled_mm_linear_kernel_config,
possible_kernels=_POSSIBLE_FP8_KERNELS, # type: ignore[misc] possible_kernels=_POSSIBLE_FP8_KERNELS, # type: ignore[arg-type]
force_kernel=force_kernel, force_kernel=force_kernel,
) )
if module_name: if module_name:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment