print("[ASP][Info] permutation_search_kernels can be imported.")
exceptImportError:
print("[ASP][Warning] permutation_search_kernels cannot be imported.")
print("[ASP][Warning] If you want to accelerate the permutation search process by GPU, please build APEX by following the instructions at https://github.com/NVIDIA/apex/blob/master/apex/contrib/sparsity/README.md")
"""This function is used to set the permutation needed parameters from ASP class."""
print("\n[set_permutation_params_from_asp] Set permutation needed parameters")
cls.__model=model
cls.__sparse_parameters=sparse_parameters
cls.__all_parameters=all_parameters
@classmethod
defset_identical_seed(cls,identical_seed=1):
print("\n[set_identical_seed] Set the identical seed: {:} for all GPUs to make sure the same results generated in permutation search".format(identical_seed))
print("\n[apply_offline_permutation] node_name: \'{:}\', node module type: \'{:}\', need to do offline permutation in K and C dims.".format(node_name,node_module_type))
print("[apply_offline_permutation][warning] node_name: \'{:}\', its real parents have trouble in permutation in K dim, so skip the offline permutation in C dim.".format(node_name,node_module_type))
print("\n[apply_offline_permutation] node_name: \'{:}\', node module type: \'{:}\', need to do offline permutation in K dim.".format(node_name,node_module_type))
else:# for BN, if the previous Conv cannot do permutation in K dim, then no need to do permutation in K dim for this BN
print("[apply_offline_permutation][warning] node_name: \'{:}\', its real parents have trouble in permutation in K dim, so skip the offline permutation in K dim.".format(node_name,node_module_type))
eliffx_graph[node_name]['permutation_type']=='C':# last layer FC/Conv
print("\n[apply_offline_permutation] node_name: \'{:}\', node module type: \'{:}\', need to do offline permutation in C dim.".format(node_name,node_module_type))
print("[apply_offline_permutation][warning] node_name: \'{:}\', its real parents have trouble in permutation in K dim, so skip the offline permutation in C dim.".format(node_name,node_module_type))
ifcls.__save_permutation_graph:
cls.save_graph_to_json(fx_graph,save_dumped_graph_path_with_name=os.path.join(cls.__permutation_output_dir,'./model_graph_apply_offline_permutation.json'))# save the intermediate graph as JSON file for debugging
returnfx_graph
@classmethod
deftransfer_to_dense_mask(cls):
"""Call this method to transfer the sparse mask to all-one dense mask."""
"""This function is used to fetch the permutation sequence value in K dim from the unique_siblings record."""
# K_permutation_sequence is its real_children's corresponding 'permutation_sequence' value stored in the fx_graph.get('unique_siblings') item which contains real_children name
# we have the assumption that all the real children are in one unique_sibling group, so should share the same permutation_sequence value
print("[apply_permutation_in_C_dim] find the node: \'{:}\' in cls.__sparse_parameters, succeed to apply permutation in C dim.".format(node_name))
is_node_in_sparse_parameters=True
temp_weight=torch.zeros_like(p)
temp_weight.copy_(p[:,permutation_sequence,...])
p.data.copy_(temp_weight)
success_permutation=True
ifis_node_in_sparse_parameters==False:
# A special case: if the node itself not in sparse_module_names but one of its real_siblings in sparse_module_names, then the node will not do the permutation search, but it may need to apply the offline permutation in C dim according to the searched permutation sequence from its real_siblings in sparse_module_names
print("[apply_permutation_in_C_dim] cannot find the node: \'{:}\' in cls.__sparse_parameters, but can find in cls.__all_parameters.".format(node_name))
print("[apply_permutation_in_C_dim] cannot find the node: \'{:}\' in cls.__sparse_parameters, after trying with cls.__all_parameters, succeed to apply permutation in C dim.".format(node_name))
except:
success_permutation=False
print("[apply_permutation_in_C_dim] cannot find the node: \'{:}\' in cls.__sparse_parameters, after trying with cls.__all_parameters, still fail to apply permutation in C dim.".format(node_name))
print("[apply_permutation_in_K_dim] find the node: \'{:}\' with \'{:}\' in cls.__all_parameters, may succeed to apply permutation in K dim.".format(node_name,p_name))
is_node_in_all_parameters=True
temp_weight=torch.zeros_like(p)
ifp.shape[0]!=len(permutation_sequence):
print("[apply_permutation_in_K_dim][warning] the node: \'{:}\' with shape: \'{:}\', cannot match the size of permutation sequence with len: \'{:}\', fail to apply permutation in K dim.".format(node_name,p.shape,len(permutation_sequence)))
success_permutation=False
else:
print("[apply_permutation_in_K_dim] the node: \'{:}\' with shape: \'{:}\', can match the size of permutation sequence with len: \'{:}\', succeed to apply permutation in K dim.".format(node_name,p.shape,len(permutation_sequence)))
temp_weight.copy_(p[permutation_sequence,...])
p.data.copy_(temp_weight)
success_permutation=True
# NOTE(review): `== False` kept as in the original; `not ...` would behave
# differently for non-bool values such as None — confirm before tightening.
if is_node_in_all_parameters == False:
    print("[apply_permutation_in_K_dim] cannot find the node: \'{:}\' in cls.__all_parameters, fail to apply permutation in K dim.".format(node_name))
print("\n[build_offline_permutation_graph] Take {:.4f} seconds to finish search_for_good_permutation function.".format(duration_search_for_good_permutation))
# Please notice the apply_offline_permutation step cannot fold into the above search_for_good_permutation step.
# Because the real_parent node needs to offline permutation in K direction according to the searched permutation sequence from its real_children.
# However, when we search_for_good_permutation for the node, its real_children have not been handled by search_for_good_permutation.
if cls.__save_permutation_graph:
    # save the intermediate graph as JSON file for debugging
    cls.save_graph_to_json(fx_graph_after_search_for_good_permutation, save_dumped_graph_path_with_name=os.path.join(cls.__permutation_output_dir, './model_graph_build_offline_permutation_graph.json'))
1. search for the good permutation sequence for each node weight, or each siblings_group weights by calling the permutation search kernels as ASP extension.
2. add the searched permutation sequence for each node according to the whole network graph built with Torch.FX."""
print("\n[search_for_good_permutation] Search for the good permutation sequence for each node according to the whole network graph built with Torch.FX")
forunique_siblings_groupinunique_siblings_groups:# loop through all unique siblings groups that must share a permutation sequence
print("\n[search_for_good_permutation] this unique_siblings_group has {:} real siblings: \'{:}\', with module type: \'{:}\'.".format(len(unique_siblings_group),unique_siblings_group,unique_siblings_groups_module_type[item_index]))
item_index=item_index+1
# concat the weight for layers in the same unique_siblings_group
print("[search_for_good_permutation] find the node: \'{:}\' in cls.__sparse_parameters, module type match: \'{:}\'.".format(node_name,node_module_type==module_type_from_sparse_parameters))
is_node_in_sparse_parameters=True
node_weight=torch.zeros_like(p)
node_weight.copy_(p)
# Need to handle the concat for layers with different R & S
print("[search_for_good_permutation] cannot find the node: \'{:}\' in cls.__sparse_parameters, no need to merge its weight for permutation.".format(node_name))
else:
ifmatrix_group==None:
matrix_group=node_weight
else:
try:
ifmatrix_group.dim()==node_weight.dim():
matrix_group=torch.cat((matrix_group,node_weight),dim=0)# concat the weights in K dimension, and keep the same C dimension
else:# e.g. when try to merge the Conv and FC layers
print("[search_for_good_permutation] matrix_group dim: {:} is not matched with node_weight dim: {:}.".format(matrix_group.dim(),node_weight.dim()))
print("[search_for_good_permutation] matrix_group shape: \'{:}\' is not matched with node_weight shape: \'{:}\'.".format(matrix_group.size(),node_weight.size()))
print("[search_for_good_permutation] matrix_group shape: \'{:}\' is now matched with node_weight shape: \'{:}\'.".format(matrix_group.size(),node_weight.size()))
matrix_group=torch.cat((matrix_group,node_weight),dim=0)# concat the weights in K dimension, and keep the same C dimension
except:
print("[search_for_good_permutation][warning] cannot merge the weight for node: \'{:}\', with its weight shape: \'{:}\', the matrix_group shape: \'{:}\'.".format(node_name,node_weight.size(),matrix_group.size()))
continue
print("[search_for_good_permutation] have merged the weight for node: \'{:}\', with its weight shape: \'{:}\', the matrix_group shape: \'{:}\'.".format(node_name,node_weight.size(),matrix_group.size()))
ifmatrix_group==None:# cannot find the node: \'{:}\' in cls.__sparse_parameters
input_channel_num=0
print("\n[search_for_good_permutation] init the all-zero list with length \'{:}\' for permutation search sequence of this unique_siblings_group.".format(input_channel_num))
print("[search_for_good_permutation] no need to search the permutation_sequence for empty matrix_group.")
print("\n[search_for_good_permutation] init the all-zero list with length \'{:}\' for permutation search sequence of this unique_siblings_group.".format(input_channel_num))
print("\n[search_for_good_permutation] Original element abs sum: {:}, Pruned element abs sum: {:}, Diff ratio: {:}".format(original_magnitude,pruned_magnitude,diff_ratio))
ifdiff_ratio<epsilon:
print("[search_for_good_permutation] Original element abs sum is almost same as the pruned element abs sum, further permutation search will not help, skipping!")
print("[search_for_good_permutation] Change the all-zero permutation search sequence to a sequential permutation search sequence.")
print("[search_for_good_permutation] Original element abs sum is different from the pruned element abs sum, further permutation search will help, continue with the permutation search!")
# call the permutation search CUDA kernels as ASP extension.
# users can provide prefer search strategy by providing a valid 'search_options' as a dictionary,
# or users can implement their customized 'accelerated_search_for_good_permutation' function.
print("[search_for_good_permutation] Change to Progressive Channel Swap Search with {} seconds limitation, because the {} is too large and will leading too long permutation search time with Exhaustive Search.".format(search_options['progressive_search_time_limit'],input_channel_num))
print("[search_for_good_permutation] Take {:.4f} seconds to finish accelerated_search_for_good_permutation function.".format(duration_accelerated_search_for_good_permutation))
cls.save_graph_to_json(fx_graph,save_dumped_graph_path_with_name=os.path.join(cls.__permutation_output_dir,'./model_graph_search_for_good_permutation.json'))# save the intermediate graph as JSON file for debugging
returnfx_graph
@classmethod
definit_permutation_flag(cls,fx_graph):
"""This function is used to init the permutation flag for each node according to the whole network graph built with Torch.FX."""
print("\n[init_permutation_flag] Init the permutation flag for each node according to the whole network graph built with Torch.FX")
# for the first (due to it is connected to 'x' node or itself is not in sparse_module_names) or not NVIDIA's TC compatiable Conv/FC, only permutate the K direction
if is_node_real_children_has_group_conv == False:
    # Permute this node only along the K (output-channel) dim; 'k_permuted'
    # records whether that permutation has been applied yet. Note it is stored
    # as the string 'False' (not a bool) — presumably for JSON round-tripping.
    fx_graph[node_name]['permutation_type'] = 'K'
    fx_graph[node_name]['k_permuted'] = 'False'
else:# if node real_children contains Group Conv, disable the permutation for node in 'K' dim
else:# if node real_parents contains Group Conv or does not need permutation in 'K' dim, disable the permutation for node in 'K' dim
fx_graph[node_name]['permutation_type']='None'
else:
fx_graph[node_name]['permutation_type']='None'
else:
fx_graph[node_name]['permutation_type']='None'
# A special case: if the node itself not in sparse_module_names but one of its real_siblings in sparse_module_names, then the node will not do the permutation search, but it may need to apply the offline permutation in C dim according to the searched permutation sequence from its real_siblings in sparse_module_names
# We make it as the post-processing, because if we add this to the previous logic, will make it too complex
print("[init_permutation_flag] node_name: \'{:}\', its original permutation: \'{:}\' already includes C dim, no need to do No.1 post-processing change.".format(node_name,node_original_permutation_type))
elifnode_original_permutation_type=='None':
fx_graph[node_name]['permutation_type']='C'
print("[init_permutation_flag] node_name: \'{:}\', change its original permutation: \'{:}\' to new permutation: 'C'.".format(node_name,node_original_permutation_type))
print("[init_permutation_flag] node_name: \'{:}\', change its original permutation: \'{:}\' to new permutation: 'KC'.".format(node_name,node_original_permutation_type))
print("[init_permutation_flag] node_name: \'{:}\', its original permutation: \'{:}\' already includes K dim, no need to do No.2 post-processing change.".format(node_name,node_original_permutation_type))
elifnode_original_permutation_type=='None':
fx_graph[node_name]['permutation_type']='K'
print("[init_permutation_flag] node_name: \'{:}\', change its original permutation: \'{:}\' to new permutation: 'K'.".format(node_name,node_original_permutation_type))
elifnode_original_permutation_type=='C':
fx_graph[node_name]['permutation_type']='KC'
print("[init_permutation_flag] node_name: \'{:}\', change its original permutation: \'{:}\' to new permutation: 'KC'.".format(node_name,node_original_permutation_type))
ifcls.__save_permutation_graph:
cls.save_graph_to_json(fx_graph,save_dumped_graph_path_with_name=os.path.join(cls.__permutation_output_dir,'./model_graph_init_permutation_flag.json'))# save the intermediate graph as JSON file for debugging
returnfx_graph
@classmethod
defextract_all_unique_siblings(cls,fx_graph):
"""This function is used to extrat all unique siblings for the whole network graph built with Torch.FX."""
print("\n[extract_all_unique_siblings] Extract all unique siblings for the whole network graph built with Torch.FX")
all_unique_siblings_name=[]
all_unique_siblings_module_type=[]
for node_name in fx_graph.keys():
    # use the 'node_type' to divide the real nodes apart from the auxiliary info node, like 'unique_siblings' node
    fx_graph[node_name]['node_type'] = 'network_node'
print("[extract_all_unique_siblings] node_name: \'{:}\', node module type: \'{:}\', has {:} real siblings: \'{:}\'.".format(node_name,node_module_type,len(node_real_siblings),node_real_siblings))
# for the two duplicated siblings lists, the node names included should be the same.
# If the node name is already included in one of the unique_siblings_name list, which means the real_siblings of this node is duplicated with the unique_siblings_name list.
# Otherwise, we should insert the [real_siblings + node_name] as a new unique_siblings_name list.
cls.save_graph_to_json(fx_graph,save_dumped_graph_path_with_name=os.path.join(cls.__permutation_output_dir,'./model_graph_extract_all_unique_siblings.json'))# save the intermediate graph as JSON file for debugging
returnfx_graph
@classmethod
deffind_real_siblings(cls,fx_graph):
"""This function is used to find all siblings for each node according to the whole network graph built with Torch.FX.
we need to find siblings recursively, because siblings may have siblings via other parents we don't know about.
"""
print("\n[find_real_siblings] Find all siblings for each node according to the whole network graph built with Torch.FX")
print("[find_real_siblings] node_name: \'{:}\', has one real sibling: \'{:}\', its real sibling module type: \'{:}\'.".format(node_name,real_child_item,sibling_module_type))
cls.save_graph_to_json(fx_graph,save_dumped_graph_path_with_name=os.path.join(cls.__permutation_output_dir,'./model_graph_find_real_siblings.json'))# save the intermediate graph as JSON file for debugging
print("[recursive_find_real_children] node_name: \'{:}\', has one real child: \'{:}\', its real child module type: \'{:}\'.".format(node_name,child_name,child_module_type))
print("[recursive_find_real_children] node_name: \'{:}\', its child: \'{:}\' with module type: \'{:}\', needs recursive search.".format(node_name,child_name,child_module_type))
sub_node_need_recursive_search.append(child_name)
else:
print("[recursive_find_real_children] node_name: \'{:}\', its child: \'{:}\' with no module type, is not its real child.".format(node_name,child_name))
"""This function is used to find the real children for each node according to the whole network graph built with Torch.FX.
For example:
The real children of Conv is the subsequent Conv/FC.
The real children of BN or other no-need-permutataion layers is the subsequent Conv/FC.
"""
print("\n[find_real_children] Find the real children for each node according to the whole network graph built with Torch.FX")
from sys import version_info  # local import kept as in the original fragment

# 'dict_keys' objects only support reversed() from Python 3.8 (bpo-33462);
# older interpreters need the keys materialized into a list first.
# `version_info >= (3, 8)` compares the whole tuple, which also stays correct
# for a hypothetical major version > 3, unlike `major == 3 and minor >= 8`.
if version_info >= (3, 8):
    reversible_fx_graph_keys = fx_graph.keys()
else:
    reversible_fx_graph_keys = list(fx_graph.keys())
fornode_nameinreversed(reversible_fx_graph_keys):# as the optimization, we need to find the real children from back to front, to use the already saved 'real_children'
print("\n[find_real_children] node_name: \'{:}\', node module type: \'{:}\', children num: {:}, recursive to find real children.".format(node_name,node_module_type,len(node_children)))
else:# Quick method, but cannot get the real children for no-need-permutataion layers like BN
print("\n[find_real_children] node_name: \'{:}\', node module type: \'{:}\', children num: {:}, can directly find real children.".format(node_name,node_module_type,len(node_children)))
# if the node is in the 'real_parents' list of the other node, then the other node is the real children for this node
print("[find_real_children] node_name: \'{:}\', has one real child: \'{:}\', its real child module type: \'{:}\'.".format(node_name,other_node_name,child_module_type))
cls.save_graph_to_json(fx_graph,save_dumped_graph_path_with_name=os.path.join(cls.__permutation_output_dir,'./model_graph_find_real_children.json'))# save the intermediate graph as JSON file for debugging
returnfx_graph
@classmethod
deffind_real_parents(cls,fx_graph):
"""This function is used to find the real parents for each node according to the whole network graph built with Torch.FX.
For example:
The real parent of BN is the previous Conv/FC.
The real parent of Conv is the previous Conv/FC.
"""
print("\n[find_real_parents] Find the real parents for each node according to the whole network graph built with Torch.FX")
print("[find_real_parents] node_name: \'{:}\', has one real parent: \'{:}\', its real parent module type: \'{:}\'.".format(node_name,parent_name,parent_module_type))
print("[find_real_parents] node_name: \'{:}\', has one/several real parent(s): \'{:}\', its real parent module type: \'{:}\'.".format(node_name,fx_graph[parent_name]['real_parents'],fx_graph[parent_name]['real_parents_module_type']))
cls.save_graph_to_json(fx_graph,save_dumped_graph_path_with_name=os.path.join(cls.__permutation_output_dir,'./model_graph_find_real_parent.json'))# save the intermediate graph as JSON file for debugging
exceptValueError:# support the none standard version
torch_version_minimum=torch_version.split('.')[2]
print("[build_fx_graph] The torch version is: {}, version major is: {}, version minor is: {}, version minimum is: {}".format(torch_version,torch_version_major,torch_version_minor,torch_version_minimum))
print("[build_fx_graph] This is the \'call_function\' node: {:}, its parent list: {:}, its children list: {:}".format(converted_node_name,node_parent,node_children))
print("[build_fx_graph] This is the \'call_method\' node: {:}, its parent list: {:}, its children list: {:}".format(converted_node_name,node_parent,node_children))
# check whether the converted_node_name is same as node.target, especially for ReLU case
if converted_node_name != node.target:
    print("[build_fx_graph][warning] The target name from Torch.FX is \'{:}\', the manually converted node name is \'{:}\', not the same one, choose the converted node name".format(node.target, converted_node_name))
# assume the modules share the same target name have the same type, because converted_node_name may not be obtained by model.named_modules(), like some ReLU (defined in forward function)
node_type = module_name_type_dict[node.target]
print("[build_fx_graph] This is the \'call_module\' node: {:}, its parent list: {:}, its children list: {:}, its type: {:}".format(converted_node_name, node_parent, node_children, node_type))