Unverified Commit 00ec9ff1 authored by anj-s's avatar anj-s Committed by GitHub
Browse files

[fix] Add the latest numpy version to requirements.txt (#731)

* set numpy version

* remove numpy requirement

* remove numpy plugin

* add numpy requirements
parent 4a63034e
......@@ -533,7 +533,7 @@ class AdaScale(Optimizer):
continue
# must be a np array, extend it with the right value and check the shape.
val = 1 if name == "grad_sqr_avg" else 0
self._state[name] = np.append(self._state[name], val)
self._state[name] = np.append(self._state[name], val) # type: ignore
assert self._state[name].shape == (len(self._optimizer.param_groups),)
def zero_grad(self) -> None:
......
# FairScale should only depends on torch, not things higher level than torch.
torch >= 1.6.0
numpy >= 1.21
\ No newline at end of file
......@@ -42,6 +42,7 @@ exclude = build,*.pyi,.git
[mypy]
mypy_path = ./stubs/
follow_imports = normal
plugins = numpy.typing.mypy_plugin
# This project must be strictly typed.
[mypy-fairscale.*]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment