Unverified Commit a7d53a59 authored by hlky, committed by GitHub

Don't override `torch_dtype` and don't use when `quantization_config` is set (#11039)



* Don't use `torch_dtype` when `quantization_config` is set

* up

* djkajka

* Apply suggestions from code review

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 8a63aa5e
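In practice, the loader entry points (`from_pretrained`, `from_single_file`) previously defaulted `torch_dtype` to `torch.float32`, which could silently fight with the dtype handling implied by a `quantization_config`. After this change, `torch_dtype` defaults to `None` and is not applied on top of quantized weights when a quantization config is supplied. A minimal sketch of the intended usage (the checkpoint id and 4-bit settings below are illustrative, not part of this commit):

```python
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

# Quantize on load and let the quantization backend manage dtypes;
# `torch_dtype` is intentionally omitted (it now defaults to None).
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint
    subfolder="transformer",
    quantization_config=quant_config,
)
```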
@@ -360,12 +360,12 @@ class FromSingleFileMixin:
        cache_dir = kwargs.pop("cache_dir", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
+        torch_dtype = kwargs.pop("torch_dtype", None)
        disable_mmap = kwargs.pop("disable_mmap", False)

        is_legacy_loading = False

-        if not isinstance(torch_dtype, torch.dtype):
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
            torch_dtype = torch.float32
            logger.warning(
                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
@@ -255,12 +255,12 @@ class FromOriginalModelMixin:
        subfolder = kwargs.pop("subfolder", None)
        revision = kwargs.pop("revision", None)
        config_revision = kwargs.pop("config_revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
+        torch_dtype = kwargs.pop("torch_dtype", None)
        quantization_config = kwargs.pop("quantization_config", None)
        device = kwargs.pop("device", None)
        disable_mmap = kwargs.pop("disable_mmap", False)

-        if not isinstance(torch_dtype, torch.dtype):
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
            torch_dtype = torch.float32
            logger.warning(
                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
@@ -880,7 +880,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
+        torch_dtype = kwargs.pop("torch_dtype", None)
        subfolder = kwargs.pop("subfolder", None)
        device_map = kwargs.pop("device_map", None)
        max_memory = kwargs.pop("max_memory", None)
@@ -893,7 +893,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
        disable_mmap = kwargs.pop("disable_mmap", False)

-        if not isinstance(torch_dtype, torch.dtype):
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
            torch_dtype = torch.float32
            logger.warning(
                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
@@ -104,13 +104,6 @@ class RMSNorm(torch.nn.Module):
        return (self.weight * hidden_states).to(input_dtype)


-def _config_to_kwargs(args):
-    common_kwargs = {
-        "dtype": args.torch_dtype,
-    }
-    return common_kwargs
-
-
class CoreAttention(torch.nn.Module):
    def __init__(self, config: ChatGLMConfig, layer_number):
        super(CoreAttention, self).__init__()
@@ -314,7 +307,6 @@ class SelfAttention(torch.nn.Module):
            self.qkv_hidden_size,
            bias=config.add_bias_linear or config.add_qkv_bias,
            device=device,
-            **_config_to_kwargs(config),
        )

        self.core_attention = CoreAttention(config, self.layer_number)
@@ -325,7 +317,6 @@ class SelfAttention(torch.nn.Module):
            config.hidden_size,
            bias=config.add_bias_linear,
            device=device,
-            **_config_to_kwargs(config),
        )

    def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
@@ -449,7 +440,6 @@ class MLP(torch.nn.Module):
            config.ffn_hidden_size * 2,
            bias=self.add_bias,
            device=device,
-            **_config_to_kwargs(config),
        )

        def swiglu(x):
@@ -459,9 +449,7 @@ class MLP(torch.nn.Module):
        self.activation_func = swiglu

        # Project back to h.
-        self.dense_4h_to_h = nn.Linear(
-            config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device, **_config_to_kwargs(config)
-        )
+        self.dense_4h_to_h = nn.Linear(config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device)

    def forward(self, hidden_states):
        # [s, b, 4hp]
@@ -488,18 +476,14 @@ class GLMBlock(torch.nn.Module):
        LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
        # Layernorm on the input data.
-        self.input_layernorm = LayerNormFunc(
-            config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
-        )
+        self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)

        # Self attention.
        self.self_attention = SelfAttention(config, layer_number, device=device)
        self.hidden_dropout = config.hidden_dropout

        # Layernorm on the attention output
-        self.post_attention_layernorm = LayerNormFunc(
-            config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
-        )
+        self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)

        # MLP
        self.mlp = MLP(config, device=device)
@@ -569,9 +553,7 @@ class GLMTransformer(torch.nn.Module):
        if self.post_layer_norm:
            LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
-            self.final_layernorm = LayerNormFunc(
-                config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
-            )
+            self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)

        self.gradient_checkpointing = False
@@ -679,9 +661,7 @@ class Embedding(torch.nn.Module):
        self.hidden_size = config.hidden_size
        # Word embeddings (parallel).
-        self.word_embeddings = nn.Embedding(
-            config.padded_vocab_size, self.hidden_size, dtype=config.torch_dtype, device=device
-        )
+        self.word_embeddings = nn.Embedding(config.padded_vocab_size, self.hidden_size, device=device)
        self.fp32_residual_connection = config.fp32_residual_connection

    def forward(self, input_ids):
@@ -784,16 +764,13 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
            config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
        )

-        self.rotary_pos_emb = RotaryEmbedding(
-            rotary_dim // 2, original_impl=config.original_rope, device=device, dtype=config.torch_dtype
-        )
+        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device)
        self.encoder = init_method(GLMTransformer, config, **init_kwargs)
        self.output_layer = init_method(
            nn.Linear,
            config.hidden_size,
            config.padded_vocab_size,
            bias=False,
-            dtype=config.torch_dtype,
            **init_kwargs,
        )
        self.pre_seq_len = config.pre_seq_len
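The Kolors text-encoder hunks above drop the hard-coded `dtype=config.torch_dtype` from the ChatGLM layer constructors, so the encoder's dtype is now governed by the `torch_dtype` passed at load time rather than by the model config. A minimal sketch of the resulting behavior (the checkpoint id comes from the tests further down; loading in bfloat16 is only an illustration):

```python
import torch
from diffusers.pipelines.kolors.text_encoder import ChatGLMModel

# With the constructor-level dtype removed, the requested torch_dtype
# determines the parameter dtype; omitting it keeps the default float32.
text_encoder = ChatGLMModel.from_pretrained(
    "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
)
print(text_encoder.dtype)  # expected: torch.bfloat16
```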
@@ -686,7 +686,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        from_flax = kwargs.pop("from_flax", False)
-        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
+        torch_dtype = kwargs.pop("torch_dtype", None)
        custom_pipeline = kwargs.pop("custom_pipeline", None)
        custom_revision = kwargs.pop("custom_revision", None)
        provider = kwargs.pop("provider", None)
@@ -703,7 +703,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
        use_onnx = kwargs.pop("use_onnx", None)
        load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)

-        if not isinstance(torch_dtype, torch.dtype):
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
            torch_dtype = torch.float32
            logger.warning(
                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
@@ -1456,8 +1456,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
        if load_components_from_hub and not trust_remote_code:
            raise ValueError(
-                f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k,v in custom_components.items()])} which must be executed to correctly "
-                f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k,v in custom_components.items()])}.\n"
+                f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k, v in custom_components.items()])} which must be executed to correctly "
+                f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k, v in custom_components.items()])}.\n"
                f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
            )
@@ -90,7 +90,7 @@ class KolorsPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        )
        torch.manual_seed(0)
        text_encoder = ChatGLMModel.from_pretrained(
-            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float32
        )
        tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
@@ -94,7 +94,7 @@ class KolorsPipelineImg2ImgFastTests(PipelineTesterMixin, unittest.TestCase):
        )
        torch.manual_seed(0)
        text_encoder = ChatGLMModel.from_pretrained(
-            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float32
        )
        tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
@@ -99,7 +99,7 @@ class KolorsPAGPipelineFastTests(
        )
        torch.manual_seed(0)
        text_encoder = ChatGLMModel.from_pretrained(
-            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float32
        )
        tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")