Unverified Commit a7d53a59 authored by hlky, committed by GitHub

Don't override `torch_dtype` and don't use it when `quantization_config` is set (#11039)



* Don't use `torch_dtype` when `quantization_config` is set

* up

* djkajka

* Apply suggestions from code review

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 8a63aa5e
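In practice the change means a quantized checkpoint can now be loaded without passing `torch_dtype` at all: omitting it no longer falls back to a hard-coded `torch.float32`, and the `quantization_config` alone decides how the weights end up in memory. A minimal sketch of that usage, assuming diffusers with bitsandbytes installed (the checkpoint id is only an example):

import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

# 4-bit NF4 quantization; the compute dtype is part of the quantization config,
# so no separate torch_dtype needs to be supplied.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # example checkpoint
    subfolder="transformer",
    quantization_config=quant_config,
)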
@@ -360,12 +360,12 @@ class FromSingleFileMixin:
         cache_dir = kwargs.pop("cache_dir", None)
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
+        torch_dtype = kwargs.pop("torch_dtype", None)
         disable_mmap = kwargs.pop("disable_mmap", False)

         is_legacy_loading = False

-        if not isinstance(torch_dtype, torch.dtype):
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
             torch_dtype = torch.float32
             logger.warning(
                 f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
...
@@ -255,12 +255,12 @@ class FromOriginalModelMixin:
         subfolder = kwargs.pop("subfolder", None)
         revision = kwargs.pop("revision", None)
         config_revision = kwargs.pop("config_revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
+        torch_dtype = kwargs.pop("torch_dtype", None)
         quantization_config = kwargs.pop("quantization_config", None)
         device = kwargs.pop("device", None)
         disable_mmap = kwargs.pop("disable_mmap", False)

-        if not isinstance(torch_dtype, torch.dtype):
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
             torch_dtype = torch.float32
             logger.warning(
                 f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
...
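The same applies to the single-file path: `from_single_file` already pops `quantization_config`, and with this change `torch_dtype` can simply be left out when loading, for example, a GGUF checkpoint. A rough sketch under those assumptions (the file path is a placeholder):

import torch
from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

# The GGUF file carries quantized weights; the compute dtype comes from the
# quantization config rather than from a torch_dtype override.
transformer = FluxTransformer2DModel.from_single_file(
    "path/to/flux1-dev-Q4_K_S.gguf",  # placeholder path
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
)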
@@ -880,7 +880,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         local_files_only = kwargs.pop("local_files_only", None)
         token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
+        torch_dtype = kwargs.pop("torch_dtype", None)
         subfolder = kwargs.pop("subfolder", None)
         device_map = kwargs.pop("device_map", None)
         max_memory = kwargs.pop("max_memory", None)
@@ -893,7 +893,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
         disable_mmap = kwargs.pop("disable_mmap", False)

-        if not isinstance(torch_dtype, torch.dtype):
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
             torch_dtype = torch.float32
             logger.warning(
                 f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
...
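For `ModelMixin.from_pretrained` the validation is otherwise unchanged: a value that is not a `torch.dtype` is still coerced to `torch.float32` with a warning, while `None` (the new default) is passed through so the weights are intended to keep the precision they were saved in. A small illustration of the two paths, with a placeholder repo id:

import torch
from diffusers import UNet2DConditionModel

# No torch_dtype: weights stay in their stored precision instead of being
# upcast to a float32 default.
unet = UNet2DConditionModel.from_pretrained("org/some-checkpoint", subfolder="unet")

# A string is not a torch.dtype, so this still logs a warning and loads in float32.
unet_fp32 = UNet2DConditionModel.from_pretrained(
    "org/some-checkpoint", subfolder="unet", torch_dtype="fp16"
)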
@@ -104,13 +104,6 @@ class RMSNorm(torch.nn.Module):
         return (self.weight * hidden_states).to(input_dtype)


-def _config_to_kwargs(args):
-    common_kwargs = {
-        "dtype": args.torch_dtype,
-    }
-    return common_kwargs
-
-
 class CoreAttention(torch.nn.Module):
     def __init__(self, config: ChatGLMConfig, layer_number):
         super(CoreAttention, self).__init__()
@@ -314,7 +307,6 @@ class SelfAttention(torch.nn.Module):
             self.qkv_hidden_size,
             bias=config.add_bias_linear or config.add_qkv_bias,
             device=device,
-            **_config_to_kwargs(config),
         )

         self.core_attention = CoreAttention(config, self.layer_number)
@@ -325,7 +317,6 @@ class SelfAttention(torch.nn.Module):
             config.hidden_size,
             bias=config.add_bias_linear,
             device=device,
-            **_config_to_kwargs(config),
         )

     def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
@@ -449,7 +440,6 @@ class MLP(torch.nn.Module):
             config.ffn_hidden_size * 2,
             bias=self.add_bias,
             device=device,
-            **_config_to_kwargs(config),
         )

         def swiglu(x):
@@ -459,9 +449,7 @@ class MLP(torch.nn.Module):
         self.activation_func = swiglu

         # Project back to h.
-        self.dense_4h_to_h = nn.Linear(
-            config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device, **_config_to_kwargs(config)
-        )
+        self.dense_4h_to_h = nn.Linear(config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device)

     def forward(self, hidden_states):
         # [s, b, 4hp]
@@ -488,18 +476,14 @@ class GLMBlock(torch.nn.Module):
         LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
         # Layernorm on the input data.
-        self.input_layernorm = LayerNormFunc(
-            config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
-        )
+        self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)

         # Self attention.
         self.self_attention = SelfAttention(config, layer_number, device=device)
         self.hidden_dropout = config.hidden_dropout

         # Layernorm on the attention output
-        self.post_attention_layernorm = LayerNormFunc(
-            config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
-        )
+        self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)

         # MLP
         self.mlp = MLP(config, device=device)
@@ -569,9 +553,7 @@ class GLMTransformer(torch.nn.Module):
         if self.post_layer_norm:
             LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
             # Final layer norm before output.
-            self.final_layernorm = LayerNormFunc(
-                config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
-            )
+            self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)

         self.gradient_checkpointing = False
@@ -679,9 +661,7 @@ class Embedding(torch.nn.Module):
         self.hidden_size = config.hidden_size
         # Word embeddings (parallel).
-        self.word_embeddings = nn.Embedding(
-            config.padded_vocab_size, self.hidden_size, dtype=config.torch_dtype, device=device
-        )
+        self.word_embeddings = nn.Embedding(config.padded_vocab_size, self.hidden_size, device=device)
         self.fp32_residual_connection = config.fp32_residual_connection

     def forward(self, input_ids):
@@ -784,16 +764,13 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
         )

-        self.rotary_pos_emb = RotaryEmbedding(
-            rotary_dim // 2, original_impl=config.original_rope, device=device, dtype=config.torch_dtype
-        )
+        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device)
         self.encoder = init_method(GLMTransformer, config, **init_kwargs)
         self.output_layer = init_method(
             nn.Linear,
             config.hidden_size,
             config.padded_vocab_size,
             bias=False,
-            dtype=config.torch_dtype,
             **init_kwargs,
         )
         self.pre_seq_len = config.pre_seq_len
...
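With `_config_to_kwargs` and the baked-in `config.torch_dtype` arguments gone, the Kolors `ChatGLMModel` text encoder builds its submodules in the default dtype and the precision is picked at load time instead. A brief sketch of what that looks like, using the tiny test checkpoint referenced below:

import torch
from diffusers.pipelines.kolors import ChatGLMModel

# dtype now comes from the loader, not from dtype=config.torch_dtype wired into
# every nn.Linear / nn.Embedding / layer-norm constructor.
text_encoder = ChatGLMModel.from_pretrained(
    "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float16
)
print(next(text_encoder.parameters()).dtype)  # expected: torch.float16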
@@ -686,7 +686,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
         from_flax = kwargs.pop("from_flax", False)
-        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
+        torch_dtype = kwargs.pop("torch_dtype", None)
         custom_pipeline = kwargs.pop("custom_pipeline", None)
         custom_revision = kwargs.pop("custom_revision", None)
         provider = kwargs.pop("provider", None)
@@ -703,7 +703,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         use_onnx = kwargs.pop("use_onnx", None)
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)

-        if not isinstance(torch_dtype, torch.dtype):
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
             torch_dtype = torch.float32
             logger.warning(
                 f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
@@ -1456,8 +1456,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         if load_components_from_hub and not trust_remote_code:
             raise ValueError(
-                f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k,v in custom_components.items()])} which must be executed to correctly "
-                f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k,v in custom_components.items()])}.\n"
+                f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k, v in custom_components.items()])} which must be executed to correctly "
+                f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k, v in custom_components.items()])}.\n"
                 f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
             )
...
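At the pipeline level the effect is the same: `DiffusionPipeline.from_pretrained` no longer substitutes `torch.float32` when `torch_dtype` is omitted, so components are meant to load in whatever precision their weights were saved in, while an explicit dtype behaves as before. A sketch with a placeholder repo id:

import torch
from diffusers import DiffusionPipeline

# Explicit dtype: unchanged behaviour.
pipe_fp16 = DiffusionPipeline.from_pretrained("org/some-pipeline", torch_dtype=torch.float16)

# No dtype: components keep their saved precision instead of a float32 default.
pipe = DiffusionPipeline.from_pretrained("org/some-pipeline")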
@@ -90,7 +90,7 @@ class KolorsPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         )
         torch.manual_seed(0)
         text_encoder = ChatGLMModel.from_pretrained(
-            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float32
         )
         tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
...
@@ -94,7 +94,7 @@ class KolorsPipelineImg2ImgFastTests(PipelineTesterMixin, unittest.TestCase):
         )
         torch.manual_seed(0)
         text_encoder = ChatGLMModel.from_pretrained(
-            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float32
        )
         tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
...
@@ -99,7 +99,7 @@ class KolorsPAGPipelineFastTests(
         )
         torch.manual_seed(0)
         text_encoder = ChatGLMModel.from_pretrained(
-            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float32
         )
         tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
...