Commit d9a3b7f0 authored by Frederick Liu's avatar Frederick Liu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 348075411
parent be159399
...@@ -47,10 +47,13 @@ class MaskedLMConfig(cfg.TaskConfig): ...@@ -47,10 +47,13 @@ class MaskedLMConfig(cfg.TaskConfig):
class MaskedLMTask(base_task.Task): class MaskedLMTask(base_task.Task):
"""Task object for Mask language modeling.""" """Task object for Mask language modeling."""
def _build_encoder(self, encoder_cfg):
return encoders.build_encoder(encoder_cfg)
def build_model(self, params=None): def build_model(self, params=None):
config = params or self.task_config.model config = params or self.task_config.model
encoder_cfg = config.encoder encoder_cfg = config.encoder
encoder_network = encoders.build_encoder(encoder_cfg) encoder_network = self._build_encoder(encoder_cfg)
cls_heads = [ cls_heads = [
layers.ClassificationHead(**cfg.as_dict()) for cfg in config.cls_heads layers.ClassificationHead(**cfg.as_dict()) for cfg in config.cls_heads
] if config.cls_heads else [] ] if config.cls_heads else []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment