Unverified Commit f024e003 authored by Alexander Pivovarov, committed by GitHub

Fix typos (#2715)


Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 2120b4ee
@@ -69,7 +69,7 @@ class AttentionBlock(nn.Module):
self.value = nn.Linear(channels, channels)
self.rescale_output_factor = rescale_output_factor
- self.proj_attn = nn.Linear(channels, channels, 1)
+ self.proj_attn = nn.Linear(channels, channels, bias=True)
self._use_memory_efficient_attention_xformers = False
self._attention_op = None
...
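Note on this hunk: the third positional argument of `torch.nn.Linear` is `bias`, so the original `nn.Linear(channels, channels, 1)` already behaved as `bias=True` (1 is truthy); the rewrite only makes the intent explicit. A minimal sketch of the equivalence (PyTorch only, the channel size is chosen arbitrarily):

```python
import torch.nn as nn

# nn.Linear(in_features, out_features, bias=True): the third positional
# argument is `bias`, so passing 1 (truthy) builds the same layer as bias=True.
positional = nn.Linear(64, 64, 1)
keyword = nn.Linear(64, 64, bias=True)

assert positional.bias is not None
assert keyword.bias is not None
assert positional.bias.shape == keyword.bias.shape  # both carry a (64,)-shaped bias
```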
@@ -344,7 +344,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
`processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
of **all** `Attention` layers.
- In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.:
+ In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
"""
count = len(self.attn_processors.keys())
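For context, the docstring above belongs to `set_attn_processor`, which takes either one processor applied to every `Attention` layer or a dict keyed by each layer's path. A hedged usage sketch; the `AttnProcessor` import path has moved between diffusers versions, and `model` stands for any instance exposing `attn_processors` (e.g. a `ControlNetModel`):

```python
from diffusers.models.attention_processor import AttnProcessor  # path may differ by version

# Single processor applied to all Attention layers.
model.set_attn_processor(AttnProcessor())

# Or a dict: reuse the model's own attn_processors keys so every path gets a processor.
processors = {name: AttnProcessor() for name in model.attn_processors.keys()}
model.set_attn_processor(processors)
```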
@@ -379,24 +379,24 @@ class ControlNetModel(ModelMixin, ConfigMixin):
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
- `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
- def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
sliceable_head_dims.append(module.sliceable_head_dim)
for child in module.children():
- fn_recursive_retrieve_slicable_dims(child)
+ fn_recursive_retrieve_sliceable_dims(child)
# retrieve number of attention layers
for module in self.children():
- fn_recursive_retrieve_slicable_dims(module)
+ fn_recursive_retrieve_sliceable_dims(module)
- num_slicable_layers = len(sliceable_head_dims)
+ num_sliceable_layers = len(sliceable_head_dims)
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
@@ -404,9 +404,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == "max":
# make smallest slice possible
- slice_size = num_slicable_layers * [1]
+ slice_size = num_sliceable_layers * [1]
- slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
if len(slice_size) != len(sliceable_head_dims):
raise ValueError(
...
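The renamed `fn_recursive_retrieve_sliceable_dims` helpers collect one head dim per sliceable layer, and the code that follows normalizes `slice_size` against that list. A standalone sketch of the same normalization rules (not the library function itself), using made-up head dims:

```python
def normalize_slice_size(slice_size, sliceable_head_dims):
    """Mirror the documented set_attention_slice rules for a list of head dims."""
    if slice_size == "auto":
        # halve each attention head dim: attention runs in two steps per layer
        slice_size = [dim // 2 for dim in sliceable_head_dims]
    elif slice_size == "max":
        # smallest possible slices: lowest memory, slowest
        slice_size = len(sliceable_head_dims) * [1]
    if not isinstance(slice_size, list):
        # a single int is broadcast to every sliceable layer
        slice_size = len(sliceable_head_dims) * [slice_size]
    if len(slice_size) != len(sliceable_head_dims):
        raise ValueError("slice_size must provide one value per sliceable layer")
    return slice_size

print(normalize_slice_size("auto", [8, 16, 16]))  # [4, 8, 8]
print(normalize_slice_size(2, [8, 16, 16]))       # [2, 2, 2]
```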
@@ -575,7 +575,7 @@ class ModelMixin(torch.nn.Module):
raise ValueError(
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
- " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomely initialize"
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
" those weights or else make sure your checkpoint file is correct."
)
@@ -591,7 +591,7 @@ class ModelMixin(torch.nn.Module):
set_module_tensor_to_device(model, param_name, param_device, value=param)
else: # else let accelerate handle loading and dispatching.
# Load weights and dispatch according to the device_map
- # by deafult the device_map is None and the weights are loaded on the CPU
+ # by default the device_map is None and the weights are loaded on the CPU
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype)
loading_info = {
...
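The corrected error message points callers at the fallback for incomplete checkpoints: disable accelerate's fast loading so missing keys are randomly initialized on the CPU. A hedged sketch of that call; the repository id and subfolder are placeholders, not taken from the commit:

```python
from diffusers import UNet2DConditionModel

# low_cpu_mem_usage=False skips accelerate's meta-device fast path, and
# device_map=None keeps loading on the CPU, so missing keys can be
# randomly initialized instead of raising the ValueError above.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder model id
    subfolder="unet",
    low_cpu_mem_usage=False,
    device_map=None,
)
```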
@@ -418,7 +418,7 @@ class ResnetBlock2D(nn.Module):
time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
"ada_group" for a stronger conditioning with scale and shift.
- kernal (`torch.FloatTensor`, optional, default to None): FIR filter, see
+ kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
[`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
use_in_shortcut (`bool`, *optional*, default to `True`):
...
@@ -105,7 +105,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
- # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
# Define whether input is continuous or discrete depending on configuration
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
self.is_input_vectorized = num_vector_embeds is not None
@@ -198,7 +198,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
# 4. Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
if self.is_input_continuous:
- # TODO: should use out_channels for continous projections
+ # TODO: should use out_channels for continuous projections
if use_linear_projection:
self.proj_out = nn.Linear(inner_dim, in_channels)
else:
@@ -223,7 +223,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
"""
Args:
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
- When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
hidden_states
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
...
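The comments fixed in this file describe how `Transformer2DModel` picks its input mode from the config: continuous latents when `in_channels` is set without a `patch_size`, and discrete (vectorized) inputs when `num_vector_embeds` is set. A plain-Python sketch of that dispatch, using only the flags visible in the hunk:

```python
def infer_input_mode(in_channels=None, patch_size=None, num_vector_embeds=None):
    """Reproduce the continuous-vs-vectorized decision shown in the diff."""
    is_input_continuous = (in_channels is not None) and (patch_size is None)
    is_input_vectorized = num_vector_embeds is not None
    if is_input_continuous:
        return "continuous"   # float latents of shape (batch, channels, height, width)
    if is_input_vectorized:
        return "vectorized"   # LongTensor indices of shape (batch, num latent pixels)
    return "other"            # e.g. patch inputs, handled elsewhere in the class

print(infer_input_mode(in_channels=4))            # continuous
print(infer_input_mode(num_vector_embeds=8192))   # vectorized
```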
@@ -59,7 +59,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
obj:`(32, 32, 64)`): Tuple of block output channels.
mid_block_type (`str`, *optional*, defaults to "UNetMidBlock1D"): block type for middle of UNet.
out_block_type (`str`, *optional*, defaults to `None`): optional output processing of UNet.
- act_fn (`str`, *optional*, defaults to None): optional activitation function in UNet blocks.
+ act_fn (`str`, *optional*, defaults to None): optional activation function in UNet blocks.
norm_num_groups (`int`, *optional*, defaults to 8): group norm member count in UNet blocks.
layers_per_block (`int`, *optional*, defaults to 1): added number of layers in a UNet block.
downsample_each_block (`int`, *optional*, defaults to False:
...
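The `UNet1DModel` docstring above documents its config fields. A quick way to see the defaults it refers to is to build the model with no arguments and read the registered config; a hedged sketch (printed values depend on the installed diffusers version):

```python
from diffusers import UNet1DModel

# Instantiate with library defaults and inspect the config entries the
# docstring describes (block_out_channels, norm_num_groups, layers_per_block, ...).
model = UNet1DModel()
print(model.config.block_out_channels)
print(model.config.norm_num_groups)
print(model.config.layers_per_block)
```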
@@ -331,7 +331,7 @@ class SelfAttention1d(nn.Module):
self.key = nn.Linear(self.channels, self.channels)
self.value = nn.Linear(self.channels, self.channels)
- self.proj_attn = nn.Linear(self.channels, self.channels, 1)
+ self.proj_attn = nn.Linear(self.channels, self.channels, bias=True)
self.dropout = nn.Dropout(dropout_rate, inplace=True)
...
@@ -2684,7 +2684,7 @@ class KAttentionBlock(nn.Module):
dropout=dropout,
bias=attention_bias,
cross_attention_dim=None,
- cross_attention_norm=None,
+ cross_attention_norm=False,
)
# 2. Cross-Attn
...
@@ -197,7 +197,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
timestep_input_dim = block_out_channels[0]
else:
raise ValueError(
- f"{time_embedding_type} does not exist. Pleaes make sure to use one of `fourier` or `positional`."
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
)
self.time_embedding = TimestepEmbedding(
@@ -391,7 +391,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
`processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
of **all** `Attention` layers.
- In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.:
+ In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
"""
count = len(self.attn_processors.keys())
@@ -425,24 +425,24 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
- `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
- def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
sliceable_head_dims.append(module.sliceable_head_dim)
for child in module.children():
- fn_recursive_retrieve_slicable_dims(child)
+ fn_recursive_retrieve_sliceable_dims(child)
# retrieve number of attention layers
for module in self.children():
- fn_recursive_retrieve_slicable_dims(module)
+ fn_recursive_retrieve_sliceable_dims(module)
- num_slicable_layers = len(sliceable_head_dims)
+ num_sliceable_layers = len(sliceable_head_dims)
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
@@ -450,9 +450,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == "max":
# make smallest slice possible
- slice_size = num_slicable_layers * [1]
+ slice_size = num_sliceable_layers * [1]
- slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
if len(slice_size) != len(sliceable_head_dims):
raise ValueError(
@@ -515,7 +515,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
returning a tuple, the first element is the sample tensor.
"""
# By default samples have to be AT least a multiple of the overall upsampling factor.
- # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
default_overall_up_factor = 2**self.num_upsamplers
...
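The corrected comment explains the size check at the top of `forward`: every upsampling layer doubles the spatial resolution, so by default the input sample must be divisible by `2 ** num_upsamplers`, otherwise the interpolation output size has to be forced on the fly. A small arithmetic sketch of that check (the layer count is illustrative):

```python
# With 3 upsampling layers the default overall up factor is 2**3 = 8.
num_upsamplers = 3
default_overall_up_factor = 2**num_upsamplers

for height in (64, 60):
    if height % default_overall_up_factor != 0:
        print(height, "-> not a multiple of", default_overall_up_factor,
              "- upsample size must be forced on the fly")
    else:
        print(height, "-> fine, divisible by", default_overall_up_factor)
```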
@@ -1351,7 +1351,7 @@ class DiffusionPipeline(ConfigMixin):
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
- `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
...
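This docstring belongs to the pipeline-level attention-slicing switch (`enable_attention_slicing` in diffusers). A hedged usage sketch; the checkpoint id is a placeholder:

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")  # placeholder id

# "auto" halves each attention head dim; "max" saves the most memory by
# computing one slice at a time; an int must divide attention_head_dim.
pipe.enable_attention_slicing("auto")
# pipe.enable_attention_slicing("max")
# pipe.disable_attention_slicing()  # back to full attention
```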
@@ -287,7 +287,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
timestep_input_dim = block_out_channels[0]
else:
raise ValueError(
- f"{time_embedding_type} does not exist. Pleaes make sure to use one of `fourier` or `positional`."
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
)
self.time_embedding = TimestepEmbedding(
@@ -481,7 +481,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
`processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
of **all** `Attention` layers.
- In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.:
+ In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
"""
count = len(self.attn_processors.keys())
@@ -515,24 +515,24 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
- `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
- def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
sliceable_head_dims.append(module.sliceable_head_dim)
for child in module.children():
- fn_recursive_retrieve_slicable_dims(child)
+ fn_recursive_retrieve_sliceable_dims(child)
# retrieve number of attention layers
for module in self.children():
- fn_recursive_retrieve_slicable_dims(module)
+ fn_recursive_retrieve_sliceable_dims(module)
- num_slicable_layers = len(sliceable_head_dims)
+ num_sliceable_layers = len(sliceable_head_dims)
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
@@ -540,9 +540,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == "max":
# make smallest slice possible
- slice_size = num_slicable_layers * [1]
+ slice_size = num_sliceable_layers * [1]
- slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
if len(slice_size) != len(sliceable_head_dims):
raise ValueError(
@@ -605,7 +605,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
returning a tuple, the first element is the sample tensor.
"""
# By default samples have to be AT least a multiple of the overall upsampling factor.
- # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
default_overall_up_factor = 2**self.num_upsamplers
...
@@ -223,23 +223,23 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
output = model(**inputs_dict)
assert output is not None
- def test_model_slicable_head_dim(self):
+ def test_model_sliceable_head_dim(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
- def check_slicable_dim_attr(module: torch.nn.Module):
+ def check_sliceable_dim_attr(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
assert isinstance(module.sliceable_head_dim, int)
for child in module.children():
- check_slicable_dim_attr(child)
+ check_sliceable_dim_attr(child)
# retrieve number of attention layers
for module in model.children():
- check_slicable_dim_attr(module)
+ check_sliceable_dim_attr(module)
def test_special_attn_proc(self):
class AttnEasyProc(torch.nn.Module):
@@ -658,7 +658,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
- # there are 32 slicable layers
+ # there are 32 sliceable layers
slice_list = 16 * [2, 3]
unet = self.get_unet_model()
unet.set_attention_slice(slice_list)
...