Unverified commit 683e3cb9, authored by Robert Shaw and committed by GitHub
Browse files

[ Misc ] `fbgemm` checkpoints (#6559)

parent 9042d683
...@@ -100,8 +100,8 @@ class MarlinConfig(QuantizationConfig): ...@@ -100,8 +100,8 @@ class MarlinConfig(QuantizationConfig):
return None return None
def get_quant_method( def get_quant_method(self, layer: torch.nn.Module,
self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]: prefix: str) -> Optional["MarlinLinearMethod"]:
if (isinstance(layer, LinearBase) or if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return MarlinLinearMethod(self) return MarlinLinearMethod(self)
......
...@@ -52,8 +52,8 @@ class SqueezeLLMConfig(QuantizationConfig): ...@@ -52,8 +52,8 @@ class SqueezeLLMConfig(QuantizationConfig):
weight_bits = cls.get_from_keys(config, ["wbits"]) weight_bits = cls.get_from_keys(config, ["wbits"])
return cls(weight_bits) return cls(weight_bits)
def get_quant_method( def get_quant_method(self, layer: torch.nn.Module,
self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]: prefix: str) -> Optional[QuantizeMethodBase]:
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return SqueezeLLMLinearMethod(self) return SqueezeLLMLinearMethod(self)
return None return None
......
...@@ -105,6 +105,7 @@ def apply_fp8_linear( ...@@ -105,6 +105,7 @@ def apply_fp8_linear(
weight: torch.Tensor, weight: torch.Tensor,
weight_scale: torch.Tensor, weight_scale: torch.Tensor,
input_scale: torch.Tensor, input_scale: torch.Tensor,
input_scale_ub: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
cutlass_fp8_supported: bool = True, cutlass_fp8_supported: bool = True,
use_per_token_if_dynamic: bool = False, use_per_token_if_dynamic: bool = False,
...@@ -118,6 +119,7 @@ def apply_fp8_linear( ...@@ -118,6 +119,7 @@ def apply_fp8_linear(
qinput, x_scale = ops.scaled_fp8_quant( qinput, x_scale = ops.scaled_fp8_quant(
input, input,
input_scale, input_scale,
scale_ub=input_scale_ub,
use_per_token_if_dynamic=use_per_token_if_dynamic) use_per_token_if_dynamic=use_per_token_if_dynamic)
# Fused GEMM_DQ # Fused GEMM_DQ
......
...@@ -161,6 +161,7 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -161,6 +161,7 @@ class VocabParallelEmbedding(torch.nn.Module):
org_num_embeddings: original vocabulary size (without LoRA). org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary. padding_size: padding size for the vocabulary.
quant_config: quant config for the layer quant_config: quant config for the layer
prefix: full name of the layer in the state dict
""" # noqa: E501 """ # noqa: E501
def __init__(self, def __init__(self,
...@@ -169,7 +170,8 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -169,7 +170,8 @@ class VocabParallelEmbedding(torch.nn.Module):
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None, org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
# Keep the input dimensions. # Keep the input dimensions.
...@@ -195,7 +197,7 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -195,7 +197,7 @@ class VocabParallelEmbedding(torch.nn.Module):
linear_method = None linear_method = None
if quant_config is not None: if quant_config is not None:
linear_method = quant_config.get_quant_method(self) linear_method = quant_config.get_quant_method(self, prefix=prefix)
if linear_method is None: if linear_method is None:
linear_method = UnquantizedLinearMethod() linear_method = UnquantizedLinearMethod()
self.linear_method: QuantizeMethodBase = linear_method self.linear_method: QuantizeMethodBase = linear_method
...@@ -382,9 +384,11 @@ class ParallelLMHead(VocabParallelEmbedding): ...@@ -382,9 +384,11 @@ class ParallelLMHead(VocabParallelEmbedding):
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None, org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__(num_embeddings, embedding_dim, params_dtype, super().__init__(num_embeddings, embedding_dim, params_dtype,
org_num_embeddings, padding_size, quant_config) org_num_embeddings, padding_size, quant_config,
prefix)
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition, torch.empty(self.num_embeddings_per_partition,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment