Unverified commit f2e521c4, authored by Patrick von Platen and committed by GitHub
Browse files

[Dtype] Align dtype casting behavior with Transformers and Accelerate (#1725)

* [Dtype] Align automatic dtype

* up

* up

* fix

* re-add accelerate
parent debc74f4
...@@ -62,6 +62,7 @@ jobs: ...@@ -62,6 +62,7 @@ jobs:
run: | run: |
python -m pip install -e .[quality,test] python -m pip install -e .[quality,test]
python -m pip install -U git+https://github.com/huggingface/transformers python -m pip install -U git+https://github.com/huggingface/transformers
python -m pip install git+https://github.com/huggingface/accelerate
- name: Environment - name: Environment
run: | run: |
...@@ -134,6 +135,7 @@ jobs: ...@@ -134,6 +135,7 @@ jobs:
${CONDA_RUN} python -m pip install --upgrade pip ${CONDA_RUN} python -m pip install --upgrade pip
${CONDA_RUN} python -m pip install -e .[quality,test] ${CONDA_RUN} python -m pip install -e .[quality,test]
${CONDA_RUN} python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu ${CONDA_RUN} python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate
- name: Environment - name: Environment
shell: arch -arch arm64 bash {0} shell: arch -arch arm64 bash {0}
...@@ -157,4 +159,4 @@ jobs: ...@@ -157,4 +159,4 @@ jobs:
uses: actions/upload-artifact@v2 uses: actions/upload-artifact@v2
with: with:
name: torch_mps_test_reports name: torch_mps_test_reports
path: reports path: reports
\ No newline at end of file
...@@ -60,6 +60,7 @@ jobs: ...@@ -60,6 +60,7 @@ jobs:
apt-get update && apt-get install libsndfile1-dev -y apt-get update && apt-get install libsndfile1-dev -y
python -m pip install -e .[quality,test] python -m pip install -e .[quality,test]
python -m pip install -U git+https://github.com/huggingface/transformers python -m pip install -U git+https://github.com/huggingface/transformers
python -m pip install git+https://github.com/huggingface/accelerate
- name: Environment - name: Environment
run: | run: |
...@@ -126,6 +127,7 @@ jobs: ...@@ -126,6 +127,7 @@ jobs:
${CONDA_RUN} python -m pip install --upgrade pip ${CONDA_RUN} python -m pip install --upgrade pip
${CONDA_RUN} python -m pip install -e .[quality,test] ${CONDA_RUN} python -m pip install -e .[quality,test]
${CONDA_RUN} python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu ${CONDA_RUN} python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate
${CONDA_RUN} python -m pip install -U git+https://github.com/huggingface/transformers ${CONDA_RUN} python -m pip install -U git+https://github.com/huggingface/transformers
- name: Environment - name: Environment
......
...@@ -62,6 +62,7 @@ jobs: ...@@ -62,6 +62,7 @@ jobs:
run: | run: |
python -m pip install -e .[quality,test] python -m pip install -e .[quality,test]
python -m pip install -U git+https://github.com/huggingface/transformers python -m pip install -U git+https://github.com/huggingface/transformers
python -m pip install git+https://github.com/huggingface/accelerate
- name: Environment - name: Environment
run: | run: |
...@@ -130,6 +131,7 @@ jobs: ...@@ -130,6 +131,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install -e .[quality,test,training] python -m pip install -e .[quality,test,training]
python -m pip install git+https://github.com/huggingface/accelerate
python -m pip install -U git+https://github.com/huggingface/transformers python -m pip install -U git+https://github.com/huggingface/transformers
- name: Environment - name: Environment
...@@ -151,4 +153,4 @@ jobs: ...@@ -151,4 +153,4 @@ jobs:
uses: actions/upload-artifact@v2 uses: actions/upload-artifact@v2
with: with:
name: examples_test_reports name: examples_test_reports
path: reports path: reports
\ No newline at end of file
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect
import os import os
from functools import partial from functools import partial
from typing import Callable, List, Optional, Tuple, Union from typing import Callable, List, Optional, Tuple, Union
...@@ -489,11 +490,15 @@ class ModelMixin(torch.nn.Module): ...@@ -489,11 +490,15 @@ class ModelMixin(torch.nn.Module):
state_dict = load_state_dict(model_file) state_dict = load_state_dict(model_file)
# move the params from meta device to cpu # move the params from meta device to cpu
for param_name, param in state_dict.items(): for param_name, param in state_dict.items():
set_module_tensor_to_device(model, param_name, param_device, value=param) accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
if accepts_dtype:
set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype)
else:
set_module_tensor_to_device(model, param_name, param_device, value=param)
else: # else let accelerate handle loading and dispatching. else: # else let accelerate handle loading and dispatching.
# Load weights and dispatch according to the device_map # Load weights and dispatch according to the device_map
# by default the device_map is None and the weights are loaded on the CPU # by default the device_map is None and the weights are loaded on the CPU
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map) accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype)
loading_info = { loading_info = {
"missing_keys": [], "missing_keys": [],
...@@ -519,20 +524,6 @@ class ModelMixin(torch.nn.Module): ...@@ -519,20 +524,6 @@ class ModelMixin(torch.nn.Module):
model = cls.from_config(config, **unused_kwargs) model = cls.from_config(config, **unused_kwargs)
state_dict = load_state_dict(model_file) state_dict = load_state_dict(model_file)
dtype = set(v.dtype for v in state_dict.values())
if len(dtype) > 1 and torch.float32 not in dtype:
raise ValueError(
f"The weights of the model file {model_file} have a mixture of incompatible dtypes {dtype}. Please"
f" make sure that {model_file} weights have only one dtype."
)
elif len(dtype) > 1 and torch.float32 in dtype:
dtype = torch.float32
else:
dtype = dtype.pop()
# move model to correct dtype
model = model.to(dtype)
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model, model,
......
...@@ -70,9 +70,9 @@ class ModelTesterMixin: ...@@ -70,9 +70,9 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
model.to(dtype) model.to(dtype)
model.save_pretrained(tmpdirname) model.save_pretrained(tmpdirname)
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True) new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype)
assert new_model.dtype == dtype assert new_model.dtype == dtype
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False) new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype)
assert new_model.dtype == dtype assert new_model.dtype == dtype
def test_determinism(self): def test_determinism(self):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment