Unverified Commit 00ec9ff1 authored by anj-s's avatar anj-s Committed by GitHub
Browse files

[fix] Add the latest numpy version to requirements.txt (#731)

* set numpy version

* remove numpy requirement

* remove numpy plugin

* add numpy requirements
parent 4a63034e
...@@ -533,7 +533,7 @@ class AdaScale(Optimizer): ...@@ -533,7 +533,7 @@ class AdaScale(Optimizer):
continue continue
# must be a np array, extend it with the right value and check the shape. # must be a np array, extend it with the right value and check the shape.
val = 1 if name == "grad_sqr_avg" else 0 val = 1 if name == "grad_sqr_avg" else 0
self._state[name] = np.append(self._state[name], val) self._state[name] = np.append(self._state[name], val) # type: ignore
assert self._state[name].shape == (len(self._optimizer.param_groups),) assert self._state[name].shape == (len(self._optimizer.param_groups),)
def zero_grad(self) -> None: def zero_grad(self) -> None:
......
# FairScale should only depends on torch, not things higher level than torch. # FairScale should only depends on torch, not things higher level than torch.
torch >= 1.6.0 torch >= 1.6.0
numpy >= 1.21
\ No newline at end of file
...@@ -42,6 +42,7 @@ exclude = build,*.pyi,.git ...@@ -42,6 +42,7 @@ exclude = build,*.pyi,.git
[mypy] [mypy]
mypy_path = ./stubs/ mypy_path = ./stubs/
follow_imports = normal follow_imports = normal
plugins = numpy.typing.mypy_plugin
# This project must be strictly typed. # This project must be strictly typed.
[mypy-fairscale.*] [mypy-fairscale.*]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment