Unverified Commit 253dbfd8 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Fix SVD error for `numpy >= 1.21` in DNGO (#4466)

parent a8c12fb7
import logging import logging
import warnings
import numpy as np import numpy as np
import torch import torch
...@@ -112,7 +113,10 @@ class DNGOTuner(Tuner): ...@@ -112,7 +113,10 @@ class DNGOTuner(Tuner):
x_arr = [] x_arr = []
for x in self.x: for x in self.x:
x_arr.append([x[k] for k in sorted(x.keys())]) x_arr.append([x[k] for k in sorted(x.keys())])
self.model.train(np.array(x_arr), np.array(self.y), do_optimize=True) try:
self.model.train(np.array(x_arr), np.array(self.y), do_optimize=True)
except np.linalg.LinAlgError as e:
warnings.warn(f'numpy linalg error encountered in DNGO model training: {e}')
self._model_initialized = True self._model_initialized = True
def _get_default_value(self, value): def _get_default_value(self, value):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment