"docs/source/vscode:/vscode.git/clone" did not exist on "269e73b6011f87f4d5e6fea47f8fee11dfcdf2cc"
Unverified Commit 79b1c696 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Pytorch 1.5.0 (#3973)

* Standard deviation can no longer be set to 0

* Remove torch pinned version

* 9th instead of 10th, silly me
parent 818463ee
...@@ -67,7 +67,7 @@ extras["mecab"] = ["mecab-python3"] ...@@ -67,7 +67,7 @@ extras["mecab"] = ["mecab-python3"]
extras["sklearn"] = ["scikit-learn"] extras["sklearn"] = ["scikit-learn"]
extras["tf"] = ["tensorflow"] extras["tf"] = ["tensorflow"]
extras["tf-cpu"] = ["tensorflow-cpu"] extras["tf-cpu"] = ["tensorflow-cpu"]
extras["torch"] = ["torch==1.4.0"] extras["torch"] = ["torch"]
extras["serving"] = ["pydantic", "uvicorn", "fastapi", "starlette"] extras["serving"] = ["pydantic", "uvicorn", "fastapi", "starlette"]
extras["all"] = extras["serving"] + ["tensorflow", "torch"] extras["all"] = extras["serving"] + ["tensorflow", "torch"]
......
...@@ -45,7 +45,7 @@ def _config_zero_init(config): ...@@ -45,7 +45,7 @@ def _config_zero_init(config):
configs_no_init = copy.deepcopy(config) configs_no_init = copy.deepcopy(config)
for key in configs_no_init.__dict__.keys(): for key in configs_no_init.__dict__.keys():
if "_range" in key or "_std" in key or "initializer_factor" in key: if "_range" in key or "_std" in key or "initializer_factor" in key:
setattr(configs_no_init, key, 0.0) setattr(configs_no_init, key, 1e-10)
return configs_no_init return configs_no_init
...@@ -96,7 +96,7 @@ class ModelTesterMixin: ...@@ -96,7 +96,7 @@ class ModelTesterMixin:
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if param.requires_grad: if param.requires_grad:
self.assertIn( self.assertIn(
param.data.mean().item(), ((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0], [0.0, 1.0],
msg="Parameter {} of model {} seems not properly initialized".format(name, model_class), msg="Parameter {} of model {} seems not properly initialized".format(name, model_class),
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment