Commit 95ac4f88 authored by Sze-qq, committed by binmakeswell
Browse files

[NFC] polish...


[NFC] polish colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py code style (#1829)
Co-authored-by: siqi <siqi@siqis-MacBook-Pro.local>
parent 5da03c93
...@@ -3,9 +3,9 @@ import warnings ...@@ -3,9 +3,9 @@ import warnings
from functools import reduce from functools import reduce
import torch import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
ignore_sharding_exception from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler from .operator_handler import OperatorHandler
...@@ -71,19 +71,19 @@ class ConvHandler(OperatorHandler): ...@@ -71,19 +71,19 @@ class ConvHandler(OperatorHandler):
Argument: Argument:
sharding_size_forward(int): The forward activation will be divided sharding_size_forward(int): The forward activation will be divided
into sharding_size_forward partitions. into sharding_size_forward partitions.
sharding_size_backward_activation(int): The backward activation will sharding_size_backward_activation(int): The backward activation will
be divided into sharding_size_backward_activation partitions. be divided into sharding_size_backward_activation partitions.
sharding_size_weight(int): The backward weight will be divided sharding_size_weight(int): The backward weight will be divided
into sharding_size_weight partitions. into sharding_size_weight partitions.
Return: Return:
memory_cost(Tuple[float]): Memory cost per device with this memory_cost(Tuple[float]): Memory cost per device with this
specific strategy, the first element of this tuple is forward specific strategy, the first element of this tuple is forward
memory cost, and the second element of this tuple is backward memory cost, and the second element of this tuple is backward
memory cost. memory cost.
memory_cost_forward(float): Memory cost of forward activation per memory_cost_forward(float): Memory cost of forward activation per
device with this specific strategy. device with this specific strategy.
memory_cost_backward_activation(float): Memory cost of backward activation memory_cost_backward_activation(float): Memory cost of backward activation
per device with this specific strategy. per device with this specific strategy.
''' '''
# compute the memory cost of this strategy # compute the memory cost of this strategy
...@@ -541,14 +541,14 @@ class ConvHandler(OperatorHandler): ...@@ -541,14 +541,14 @@ class ConvHandler(OperatorHandler):
# strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]] # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
strategies_vector_for_input = StrategiesVector(node=nodes[0], in_nodes=[nodes[1], 2], strategies=strategies_for_input) strategies_vector_for_input = StrategiesVector(node=nodes[0], in_nodes=[nodes[1], 2], strategies=strategies_for_input)
setattr(nodes[1], 'strategies_vector', strategies_vector_for_input) setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[nodes[1], ]) strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[nodes[1], ])
conv_handler = ConvHandler(input_node=nodes[1], input_index=0, weight=dict(gm.named_modules())[nodes[2].name].weight, output_node=nodes[2], conv_handler = ConvHandler(input_node=nodes[1], input_index=0, weight=dict(gm.named_modules())[nodes[2].name].weight, output_node=nodes[2],
device_mesh=device_mesh, strategies_vector=strategies_vector, shape_consistency_manager=shape_consistency_manager) device_mesh=device_mesh, strategies_vector=strategies_vector, shape_consistency_manager=shape_consistency_manager)
conv_handler.register_strategy_into_strategies_vector() conv_handler.register_strategy_into_strategies_vector()
for strategy in conv_handler.strategies_vector: for strategy in conv_handler.strategies_vector:
print(f'{strategy.name}: compute_cost is {strategy.compute_cost}, communication_cost is {strategy.communication_cost}, memory_cost is {strategy.memory_cost}, resharding_costs is {strategy.resharding_costs}') print(f'{strategy.name}: compute_cost is {strategy.compute_cost}, communication_cost is {strategy.communication_cost}, memory_cost is {strategy.memory_cost}, resharding_costs is {strategy.resharding_costs}')
Output: Output:
S0S1 = S0R x RS1: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]} S0S1 = S0R x RS1: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]}
S1S0 = S1R x RS0: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]} S1S0 = S1R x RS0: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment