Unverified Commit e8b88a79 authored by J-shang, committed by GitHub

unify name speed up and speedup to speedup (#4689)

parent c5066cda
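In user code, this commit amounts to a single keyword rename in the iterative pruner constructors. A minimal before/after sketch, assuming a hypothetical toy model and placeholder finetune function (the `AGPPruner` argument order is taken from the diff below; the import path is an assumption and may differ across nni versions):

    import torch
    from torch.nn import Module
    # from nni.algorithms.compression.v2.pytorch.pruning import AGPPruner  # assumed import path

    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())  # toy model
    config_list = [{'sparsity': 0.5, 'op_types': ['Linear']}]

    def finetune(model: Module) -> None:
        pass  # placeholder finetune logic

    # Before this commit:
    #   AGPPruner(model, config_list, 'level', total_iteration=10, finetuner=finetune,
    #             speed_up=True, dummy_input=torch.rand(8, 8))
    # After this commit, the keyword is `speedup`:
    pruner = AGPPruner(model, config_list, 'level', total_iteration=10, finetuner=finetune,
                       speedup=True, dummy_input=torch.rand(8, 8))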
@@ -26,10 +26,10 @@ class PruningScheduler(BasePruningScheduler):
     finetuner
         The finetuner handled all finetune logic, use a pytorch module as input.
         It will be called at the end of each iteration if reset_weight is False, will be called at the beginning of each iteration otherwise.
-    speed_up
-        If set True, speed up the model at the end of each iteration to make the pruned model compact.
+    speedup
+        If set True, speedup the model at the end of each iteration to make the pruned model compact.
     dummy_input
-        If `speed_up` is True, `dummy_input` is required for tracing the model in speed up.
+        If `speedup` is True, `dummy_input` is required for tracing the model in speedup.
     evaluator
         Evaluate the pruned model and give a score.
         If evaluator is None, the best result refers to the latest result.
@@ -37,12 +37,12 @@ class PruningScheduler(BasePruningScheduler):
         If set True, the model weight will reset to the origin model weight at the end of each iteration step.
     """
     def __init__(self, pruner: Pruner, task_generator: TaskGenerator, finetuner: Callable[[Module], None] = None,
-                 speed_up: bool = False, dummy_input: Tensor = None, evaluator: Optional[Callable[[Module], float]] = None,
+                 speedup: bool = False, dummy_input: Tensor = None, evaluator: Optional[Callable[[Module], float]] = None,
                  reset_weight: bool = False):
         self.pruner = pruner
         self.task_generator = task_generator
         self.finetuner = finetuner
-        self.speed_up = speed_up
+        self.speedup = speedup
         self.dummy_input = dummy_input
         self.evaluator = evaluator
         self.reset_weight = reset_weight
@@ -58,7 +58,7 @@ class PruningScheduler(BasePruningScheduler):

     def pruning_one_step_normal(self, task: Task) -> TaskResult:
         """
-        generate masks -> speed up -> finetune -> evaluate
+        generate masks -> speedup -> finetune -> evaluate
         """
         model, masks, config_list = task.load_data()
         self.pruner.reset(model, config_list)
@@ -72,14 +72,14 @@ class PruningScheduler(BasePruningScheduler):
         self.pruner.show_pruned_weights()
         self.pruner._unwrap_model()

-        # speed up
-        if self.speed_up and task.speed_up:
+        # speedup
+        if self.speedup and task.speedup:
             ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
             compact_model_masks = {}

         # finetune
         if self.finetuner is not None and task.finetune:
-            if self.speed_up:
+            if self.speedup:
                 self.finetuner(compact_model)
             else:
                 self.pruner._wrap_model()
@@ -88,7 +88,7 @@ class PruningScheduler(BasePruningScheduler):

         # evaluate
         if self.evaluator is not None and task.evaluate:
-            if self.speed_up:
+            if self.speedup:
                 score = self.evaluator(compact_model)
             else:
                 self.pruner._wrap_model()
@@ -104,7 +104,7 @@ class PruningScheduler(BasePruningScheduler):

     def pruning_one_step_reset_weight(self, task: Task) -> TaskResult:
         """
-        finetune -> generate masks -> reset weight -> speed up -> evaluate
+        finetune -> generate masks -> reset weight -> speedup -> evaluate
         """
         model, masks, config_list = task.load_data()
         checkpoint = deepcopy(model.state_dict())
@@ -126,14 +126,14 @@ class PruningScheduler(BasePruningScheduler):
         # reset model weight
         compact_model.load_state_dict(checkpoint)

-        # speed up
-        if self.speed_up and task.speed_up:
+        # speedup
+        if self.speedup and task.speedup:
             ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
             compact_model_masks = {}

         # evaluate
         if self.evaluator is not None and task.evaluate:
-            if self.speed_up:
+            if self.speedup:
                 score = self.evaluator(compact_model)
             else:
                 self.pruner._wrap_model()
@@ -93,10 +93,10 @@ class LinearPruner(IterativePruner):
     finetuner : Optional[Callable[[Module], None]]
         The finetuner handled all finetune logic, use a pytorch module as input.
         It will be called at the end of each iteration, usually for neutralizing the accuracy loss brought by the pruning in this iteration.
-    speed_up : bool
-        If set True, speed up the model at the end of each iteration to make the pruned model compact.
+    speedup : bool
+        If set True, speedup the model at the end of each iteration to make the pruned model compact.
     dummy_input : Optional[torch.Tensor]
-        If `speed_up` is True, `dummy_input` is required for tracing the model in speed up.
+        If `speedup` is True, `dummy_input` is required for tracing the model in speedup.
     evaluator : Optional[Callable[[Module], float]]
         Evaluate the pruned model and give a score.
         If evaluator is None, the best result refers to the latest result.
@@ -117,7 +117,7 @@ class LinearPruner(IterativePruner):
     def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
                  total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
-                 finetuner: Optional[Callable[[Module], None]] = None, speed_up: bool = False, dummy_input: Optional[Tensor] = None,
+                 finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False, dummy_input: Optional[Tensor] = None,
                  evaluator: Optional[Callable[[Module], float]] = None, pruning_params: Dict = {}):
         task_generator = LinearTaskGenerator(total_iteration=total_iteration,
                                              origin_model=model,
@@ -127,7 +127,7 @@ class LinearPruner(IterativePruner):
         if 'traced_optimizer' in pruning_params:
             pruning_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer'])
         pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
-        super().__init__(pruner, task_generator, finetuner=finetuner, speed_up=speed_up, dummy_input=dummy_input,
+        super().__init__(pruner, task_generator, finetuner=finetuner, speedup=speedup, dummy_input=dummy_input,
                          evaluator=evaluator, reset_weight=False)
@@ -158,10 +158,10 @@ class AGPPruner(IterativePruner):
     finetuner : Optional[Callable[[Module], None]]
         The finetuner handled all finetune logic, use a pytorch module as input.
         It will be called at the end of each iteration, usually for neutralizing the accuracy loss brought by the pruning in this iteration.
-    speed_up : bool
-        If set True, speed up the model at the end of each iteration to make the pruned model compact.
+    speedup : bool
+        If set True, speedup the model at the end of each iteration to make the pruned model compact.
     dummy_input : Optional[torch.Tensor]
-        If `speed_up` is True, `dummy_input` is required for tracing the model in speed up.
+        If `speedup` is True, `dummy_input` is required for tracing the model in speedup.
     evaluator : Optional[Callable[[Module], float]]
         Evaluate the pruned model and give a score.
         If evaluator is None, the best result refers to the latest result.
@@ -182,7 +182,7 @@ class AGPPruner(IterativePruner):
     def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
                  total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
-                 finetuner: Optional[Callable[[Module], None]] = None, speed_up: bool = False, dummy_input: Optional[Tensor] = None,
+                 finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False, dummy_input: Optional[Tensor] = None,
                  evaluator: Optional[Callable[[Module], float]] = None, pruning_params: Dict = {}):
         task_generator = AGPTaskGenerator(total_iteration=total_iteration,
                                           origin_model=model,
@@ -192,7 +192,7 @@ class AGPPruner(IterativePruner):
         if 'traced_optimizer' in pruning_params:
             pruning_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer'])
         pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
-        super().__init__(pruner, task_generator, finetuner=finetuner, speed_up=speed_up, dummy_input=dummy_input,
+        super().__init__(pruner, task_generator, finetuner=finetuner, speedup=speedup, dummy_input=dummy_input,
                          evaluator=evaluator, reset_weight=False)
@@ -234,10 +234,10 @@ class LotteryTicketPruner(IterativePruner):
     finetuner : Optional[Callable[[Module], None]]
         The finetuner handled all finetune logic, use a pytorch module as input.
         It will be called at the end of each iteration if reset_weight is False, will be called at the beginning of each iteration otherwise.
-    speed_up : bool
-        If set True, speed up the model at the end of each iteration to make the pruned model compact.
+    speedup : bool
+        If set True, speedup the model at the end of each iteration to make the pruned model compact.
     dummy_input : Optional[torch.Tensor]
-        If `speed_up` is True, `dummy_input` is required for tracing the model in speed up.
+        If `speedup` is True, `dummy_input` is required for tracing the model in speedup.
     evaluator : Optional[Callable[[Module], float]]
         Evaluate the pruned model and give a score.
         If evaluator is None, the best result refers to the latest result.
@@ -261,7 +261,7 @@ class LotteryTicketPruner(IterativePruner):
     def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
                  total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
-                 finetuner: Optional[Callable[[Module], None]] = None, speed_up: bool = False, dummy_input: Optional[Tensor] = None,
+                 finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False, dummy_input: Optional[Tensor] = None,
                  evaluator: Optional[Callable[[Module], float]] = None, reset_weight: bool = True,
                  pruning_params: Dict = {}):
         task_generator = LotteryTicketTaskGenerator(total_iteration=total_iteration,
@@ -272,7 +272,7 @@ class LotteryTicketPruner(IterativePruner):
         if 'traced_optimizer' in pruning_params:
             pruning_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer'])
         pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
-        super().__init__(pruner, task_generator, finetuner=finetuner, speed_up=speed_up, dummy_input=dummy_input,
+        super().__init__(pruner, task_generator, finetuner=finetuner, speedup=speedup, dummy_input=dummy_input,
                          evaluator=evaluator, reset_weight=reset_weight)
@@ -318,10 +318,10 @@ class SimulatedAnnealingPruner(IterativePruner):
         If keeping the intermediate result, including intermediate model and masks during each iteration.
     finetuner : Optional[Callable[[Module], None]]
         The finetuner handled all finetune logic, use a pytorch module as input, will be called in each iteration.
-    speed_up : bool
-        If set True, speed up the model at the end of each iteration to make the pruned model compact.
+    speedup : bool
+        If set True, speedup the model at the end of each iteration to make the pruned model compact.
     dummy_input : Optional[torch.Tensor]
-        If `speed_up` is True, `dummy_input` is required for tracing the model in speed up.
+        If `speedup` is True, `dummy_input` is required for tracing the model in speedup.

     Examples
     --------
@@ -340,7 +340,7 @@ class SimulatedAnnealingPruner(IterativePruner):
     def __init__(self, model: Module, config_list: List[Dict], evaluator: Callable[[Module], float], start_temperature: float = 100,
                  stop_temperature: float = 20, cool_down_rate: float = 0.9, perturbation_magnitude: float = 0.35,
                  pruning_algorithm: str = 'level', pruning_params: Dict = {}, log_dir: str = '.', keep_intermediate_result: bool = False,
-                 finetuner: Optional[Callable[[Module], None]] = None, speed_up: bool = False, dummy_input: Optional[Tensor] = None):
+                 finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False, dummy_input: Optional[Tensor] = None):
         task_generator = SimulatedAnnealingTaskGenerator(origin_model=model,
                                                          origin_config_list=config_list,
                                                          start_temperature=start_temperature,
@@ -352,5 +352,5 @@ class SimulatedAnnealingPruner(IterativePruner):
         if 'traced_optimizer' in pruning_params:
             pruning_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer'])
         pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
-        super().__init__(pruner, task_generator, finetuner=finetuner, speed_up=speed_up, dummy_input=dummy_input,
+        super().__init__(pruner, task_generator, finetuner=finetuner, speedup=speedup, dummy_input=dummy_input,
                          evaluator=evaluator, reset_weight=False)
@@ -239,7 +239,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
         low_limit = 0
         while True:
-            # This is to speed up finding the legal sparsity.
+            # This is to speedup finding the legal sparsity.
             low_limit = (1 - low_limit) * 0.05 + low_limit
             random_sparsity = sorted(np.random.uniform(low_limit, 1, len(op_names)))
             rescaled_sparsity = self._rescale_sparsity(random_sparsity, target_sparsity, op_names)
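For intuition, the retry loop above moves `low_limit` 5% of its remaining distance toward 1 on every attempt, progressively narrowing the range that `np.random.uniform` samples from until a legal sparsity vector is found. A standalone sketch of the ramp (illustration only, not library code):

    low_limit = 0.0
    for attempt in range(5):
        # move 5% of the remaining distance toward 1 on each retry
        low_limit = (1 - low_limit) * 0.05 + low_limit
        print(f"attempt {attempt}: low_limit = {low_limit:.4f}")
    # prints 0.0500, 0.0975, 0.1426, 0.1855, 0.2262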
@@ -198,16 +198,16 @@ def compute_sparsity(origin_model: Module, compact_model: Module, compact_model_
     The current state means `compact_model` + `compact_model_masks`
     (i.e., `compact_model_masks` applied on `compact_model`).
     The compact model is the origin model after pruning,
-    and it may have different structure with origin_model cause of speed up.
+    and it may have different structure with origin_model cause of speedup.

     Parameters
     ----------
     origin_model : torch.nn.Module
         The original un-pruned model.
     compact_model : torch.nn.Module
-        The model after speed up or original model.
+        The model after speedup or original model.
     compact_model_masks: Dict[str, Dict[str, Tensor]]
-        The masks applied on the compact model, if the original model have been speed up, this should be {}.
+        The masks applied on the compact model, if the original model have been speedup, this should be {}.
     config_list : List[Dict]
         The config_list used by pruning the original model.
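The docstring above distinguishes two states of the same network: the original dense model, and the compact model with any residual masks. A hedged call sketch (the import path, variable names, and return handling are assumptions; only the argument pattern follows the docstring, which says `compact_model_masks` should be `{}` once the model is really compact):

    # from nni.compression.pytorch.utils import compute_sparsity  # assumed import path
    sparsity_info = compute_sparsity(origin_model=origin_model,    # dense model before pruning
                                     compact_model=speedup_model,  # model after ModelSpeedup
                                     compact_model_masks={},       # empty: masks already applied physically
                                     config_list=config_list)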
@@ -47,7 +47,7 @@ class TorchGraph:
     Parameters
     ----------
     model : pytorch model
-        The model user wants to speed up
+        The model user wants to speedup
     dummy_input : pytorch tensor
         The dummy input for ```jit.trace```, users should put it on right device before pass in
     traced_model : torch._C.torch.jit.TopLevelTracedModule
@@ -10,7 +10,7 @@ class BaseModelSpeedup:
     Parameters
     ----------
     model : pytorch model
-        The model to speed up by quantization.
+        The model to speedup by quantization.
     config : dict
         Config recording bit number and name of layers.
     """
@@ -37,7 +37,7 @@ def _setattr(model, name, module):
     Parameters
     ----------
     model : pytorch model
-        The model to speed up by quantization
+        The model to speedup by quantization
     name : str
         name of pytorch module
     module : torch.nn.Module
@@ -98,7 +98,7 @@ def torch_to_onnx(model, config, input_shape, model_path, input_names, output_na
     Parameters
     ----------
     model : pytorch model
-        The model to speed up by quantization
+        The model to speedup by quantization
     config : dict
         Config recording bits number and name of layers
     input_shape : tuple
@@ -232,7 +232,7 @@ class ModelSpeedupTensorRT(BaseModelSpeedup):
     Parameters
     ----------
     model : pytorch model
-        The model to speed up by quantization.
+        The model to speedup by quantization.
     input_shape : tuple
         The input shape of model, shall pass it to torch.onnx.export.
     config : dict
@@ -29,7 +29,7 @@ class ModelSpeedup:
     Parameters
     ----------
     model : pytorch model
-        The model user wants to speed up
+        The model user wants to speedup
     dummy_input : pytorch tensor, tuple of tensor, list of tensor
         Note: The first dimension of the dummy_input should be the batchsize.
         The dummy input for ```jit.trace```, users should put it on the right
@@ -499,7 +499,7 @@ class ModelSpeedup:
         second, replace modules.
         """
-        _logger.info("start to speed up the model")
+        _logger.info("start to speedup the model")
         self.initialize_speedup()
         training = self.bound_model.training
         # set to the evaluation mode
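For reference outside the scheduler, `ModelSpeedup` can also be called directly. A minimal sketch (the model, masks, and input shape are placeholders and the import path is an assumption; the constructor arguments and `speedup_model()` call follow the scheduler diff above, and the docstring's note that the first dummy_input dimension is the batch size):

    import torch
    # from nni.compression.pytorch import ModelSpeedup  # assumed import path

    pruned_model = ...                        # placeholder: pytorch model pruned by a pruner
    masks = ...                               # placeholder: the pruner-generated masks
    dummy_input = torch.rand(1, 3, 224, 224)  # first dimension is the batch size

    # Physically remove the masked weights so the model becomes compact.
    ModelSpeedup(pruned_model, dummy_input, masks).speedup_model()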