Commit 4a577082 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Specify loader for yaml loading.

PiperOrigin-RevId: 330119437
parent a9fcda17
...@@ -227,7 +227,7 @@ class Config(params_dict.ParamsDict): ...@@ -227,7 +227,7 @@ class Config(params_dict.ParamsDict):
def from_yaml(cls, file_path: str): def from_yaml(cls, file_path: str):
# Note: This only works if the Config has all default values. # Note: This only works if the Config has all default values.
with tf.io.gfile.GFile(file_path, 'r') as f: with tf.io.gfile.GFile(file_path, 'r') as f:
loaded = yaml.load(f) loaded = yaml.load(f, Loader=yaml.FullLoader)
config = cls() config = cls()
config.override(loaded) config.override(loaded)
return config return config
......
...@@ -317,7 +317,7 @@ class ParamsDict(object): ...@@ -317,7 +317,7 @@ class ParamsDict(object):
def read_yaml_to_params_dict(file_path): def read_yaml_to_params_dict(file_path):
"""Reads a YAML file to a ParamsDict.""" """Reads a YAML file to a ParamsDict."""
with tf.io.gfile.GFile(file_path, 'r') as f: with tf.io.gfile.GFile(file_path, 'r') as f:
params_dict = yaml.load(f) params_dict = yaml.load(f, Loader=yaml.FullLoader)
return ParamsDict(params_dict) return ParamsDict(params_dict)
...@@ -438,12 +438,12 @@ def override_params_dict(params, dict_or_string_or_yaml_file, is_strict): ...@@ -438,12 +438,12 @@ def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
nested_csv_str_to_json_str(dict_or_string_or_yaml_file)) nested_csv_str_to_json_str(dict_or_string_or_yaml_file))
except ValueError: except ValueError:
pass pass
params_dict = yaml.load(dict_or_string_or_yaml_file) params_dict = yaml.load(dict_or_string_or_yaml_file, Loader=yaml.FullLoader)
if isinstance(params_dict, dict): if isinstance(params_dict, dict):
params.override(params_dict, is_strict) params.override(params_dict, is_strict)
else: else:
with tf.io.gfile.GFile(dict_or_string_or_yaml_file) as f: with tf.io.gfile.GFile(dict_or_string_or_yaml_file) as f:
params.override(yaml.load(f), is_strict) params.override(yaml.load(f, Loader=yaml.FullLoader), is_strict)
else: else:
raise ValueError('Unknown input type to parse.') raise ValueError('Unknown input type to parse.')
return params return params
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment