Unverified Commit dc277501 authored by Pedro Cuenca, committed by GitHub

Flax memory efficient attention (#2889)



* add use_memory_efficient params placeholder

* test

* add memory efficient attention jax

* add memory efficient attention jax

* newline

* forgot dot

* Rename use_memory_efficient

* Keep dtype last.

* Actually use key_chunk_size

* Rename symbol

* Apply style

* Rename use_memory_efficient

* Keep dtype last

* Pass `use_memory_efficient_attention` in `from_pretrained`

* Move JAX memory efficient attention to attention_flax.

* Simple test.

* style

---------
Co-authored-by: muhammad_hanif <muhammad_hanif@sofcograha.co.id>
Co-authored-by: MuhHanif <48muhhanif@gmail.com>
parent 0df47efe
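In short: the new `use_memory_efficient_attention` flag is threaded from `FlaxDiffusionPipeline.from_pretrained` through the UNet blocks down to `FlaxAttention`. A minimal usage sketch (not part of the diff; it mirrors the test added at the bottom and assumes access to the CompVis/stable-diffusion-v1-4 `bf16` checkpoint):

import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline

# Enable the new memory efficient attention path when loading the pipeline.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=jnp.bfloat16,
    safety_checker=None,
    use_memory_efficient_attention=True,
)

num_samples = jax.device_count()
prompt_ids = pipeline.prepare_inputs(num_samples * ["a photo of an astronaut riding a horse"])
prng_seed = jax.random.split(jax.random.PRNGKey(0), num_samples)

# Shard inputs and replicate params across devices, then run the jitted pipeline.
images = pipeline(shard(prompt_ids), replicate(params), prng_seed, jit=True).images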
@@ -12,10 +12,110 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import math

import flax.linen as nn
import jax
import jax.numpy as jnp
def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
    """Multi-head dot product attention with a limited number of queries."""
    num_kv, num_heads, k_features = key.shape[-3:]
    v_features = value.shape[-1]
    key_chunk_size = min(key_chunk_size, num_kv)
    query = query / jnp.sqrt(k_features)

    @functools.partial(jax.checkpoint, prevent_cse=False)
    def summarize_chunk(query, key, value):
        attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)

        max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
        max_score = jax.lax.stop_gradient(max_score)
        exp_weights = jnp.exp(attn_weights - max_score)

        exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
        max_score = jnp.einsum("...qhk->...qh", max_score)

        return (exp_values, exp_weights.sum(axis=-1), max_score)

    def chunk_scanner(chunk_idx):
        # julienne key array
        key_chunk = jax.lax.dynamic_slice(
            operand=key,
            start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0],  # [...,k,h,d]
            slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features],  # [...,k,h,d]
        )

        # julienne value array
        value_chunk = jax.lax.dynamic_slice(
            operand=value,
            start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0],  # [...,v,h,d]
            slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features],  # [...,v,h,d]
        )

        return summarize_chunk(query, key_chunk, value_chunk)

    chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))

    global_max = jnp.max(chunk_max, axis=0, keepdims=True)
    max_diffs = jnp.exp(chunk_max - global_max)

    chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(axis=0)
    all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)

    return all_values / all_weights
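For intuition (not part of the diff): the combination step at the end of `_query_chunk_attention` is the usual streaming-softmax identity, where each chunk's statistics are rescaled from the chunk's local max score to the global max before the weighted sums are merged. A tiny self-contained sketch with made-up numbers:

import jax
import jax.numpy as jnp

# Toy example: one query, four keys split into two chunks of two.
scores = jnp.array([1.0, 2.0, 5.0, 3.0])      # raw attention scores for one query
values = jnp.array([10.0, 20.0, 30.0, 40.0])

# Per-chunk statistics, mirroring what summarize_chunk returns.
stats = []
for s, v in [(scores[:2], values[:2]), (scores[2:], values[2:])]:
    m = s.max()           # chunk max, subtracted for numerical stability
    w = jnp.exp(s - m)    # unnormalized weights within the chunk
    stats.append((w @ v, w.sum(), m))

# Combine chunks by rescaling each one to the global max, as the code above does.
global_max = max(m for _, _, m in stats)
numerator = sum(ev * jnp.exp(m - global_max) for ev, _, m in stats)
denominator = sum(ws * jnp.exp(m - global_max) for _, ws, m in stats)

# Matches a plain softmax-weighted sum over all four keys.
assert jnp.allclose(numerator / denominator, jax.nn.softmax(scores) @ values)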
def jax_memory_efficient_attention(
    query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
):
    r"""
    Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
    https://github.com/AminRezaei0x443/memory-efficient-attention

    Args:
        query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
        key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
        value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
        precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
            numerical precision for computation
        query_chunk_size (`int`, *optional*, defaults to 1024):
            chunk size to divide query array; the value must divide query_length equally without remainder
        key_chunk_size (`int`, *optional*, defaults to 4096):
            chunk size to divide key and value array; the value must divide key_value_length equally without remainder

    Returns:
        (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
    """
    num_q, num_heads, q_features = query.shape[-3:]

    def chunk_scanner(chunk_idx, _):
        # julienne query array
        query_chunk = jax.lax.dynamic_slice(
            operand=query,
            start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0],  # [...,q,h,d]
            slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features],  # [...,q,h,d]
        )

        return (
            chunk_idx + query_chunk_size,  # unused, ignore it
            _query_chunk_attention(
                query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
            ),
        )

    _, res = jax.lax.scan(
        f=chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size)  # start counter  # stop counter
    )

    return jnp.concatenate(res, axis=-3)  # fuse the chunked result back
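A quick equivalence check one could run against this helper (illustrative only, not part of the diff; array shapes and chunk sizes are arbitrary, and equality holds up to floating-point tolerance):

import jax
import jax.numpy as jnp

q = jax.random.normal(jax.random.PRNGKey(0), (16, 8, 64))  # (query_length, heads, depth)
k = jax.random.normal(jax.random.PRNGKey(1), (32, 8, 64))  # (key_value_length, heads, depth)
v = jax.random.normal(jax.random.PRNGKey(2), (32, 8, 64))

# Chunked computation never materializes the full (16, 8, 32) score tensor at once.
chunked = jax_memory_efficient_attention(q, k, v, query_chunk_size=4, key_chunk_size=8)

# Reference: plain scaled dot-product attention over the full sequence.
scores = jnp.einsum("qhd,khd->qhk", q / jnp.sqrt(q.shape[-1]), k)
reference = jnp.einsum("qhk,khd->qhd", jax.nn.softmax(scores, axis=-1), v)

assert jnp.allclose(chunked, reference, atol=1e-4)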
class FlaxAttention(nn.Module):
    r"""
    A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
@@ -29,6 +129,8 @@ class FlaxAttention(nn.Module):
            Hidden states dimension inside each head
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
@@ -37,6 +139,7 @@ class FlaxAttention(nn.Module):
    heads: int = 8
    dim_head: int = 64
    dropout: float = 0.0
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
@@ -77,13 +180,38 @@ class FlaxAttention(nn.Module):
        key_states = self.reshape_heads_to_batch_dim(key_proj)
        value_states = self.reshape_heads_to_batch_dim(value_proj)
        if self.use_memory_efficient_attention:
            query_states = query_states.transpose(1, 0, 2)
            key_states = key_states.transpose(1, 0, 2)
            value_states = value_states.transpose(1, 0, 2)

            # this if statement creates a chunk size for each layer of the unet
            # the chunk size is equal to the query_length dimension of the deepest layer of the unet
            flatten_latent_dim = query_states.shape[-3]
            if flatten_latent_dim % 64 == 0:
                query_chunk_size = int(flatten_latent_dim / 64)
            elif flatten_latent_dim % 16 == 0:
                query_chunk_size = int(flatten_latent_dim / 16)
            elif flatten_latent_dim % 4 == 0:
                query_chunk_size = int(flatten_latent_dim / 4)
            else:
                query_chunk_size = int(flatten_latent_dim)

            hidden_states = jax_memory_efficient_attention(
                query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
            )

            hidden_states = hidden_states.transpose(1, 0, 2)
        else:
            # compute attentions
            attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
            attention_scores = attention_scores * self.scale
            attention_probs = nn.softmax(attention_scores, axis=2)

            # attend to values
            hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        hidden_states = self.proj_attn(hidden_states)
        return hidden_states
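To make the wiring concrete, here is a standalone sketch of calling this module with the flag enabled (illustrative; it assumes `FlaxAttention` stays importable from `diffusers.models.attention_flax` with the `query_dim`/`heads`/`dim_head` fields shown above):

import jax
import jax.numpy as jnp
from diffusers.models.attention_flax import FlaxAttention

attn = FlaxAttention(query_dim=320, heads=8, dim_head=40, use_memory_efficient_attention=True)

# 64x64 latent grid flattened to 4096 tokens, as in the highest-resolution UNet block at 512x512.
hidden_states = jnp.zeros((1, 64 * 64, 320), dtype=jnp.float32)
variables = attn.init(jax.random.PRNGKey(0), hidden_states)
out = attn.apply(variables, hidden_states)
assert out.shape == hidden_states.shape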
@@ -108,6 +236,8 @@ class FlaxBasicTransformerBlock(nn.Module):
            Whether to only apply cross attention.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
    """
    dim: int
    n_heads: int
@@ -115,12 +245,17 @@ class FlaxBasicTransformerBlock(nn.Module):
    dropout: float = 0.0
    only_cross_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    use_memory_efficient_attention: bool = False

    def setup(self):
        # self attention (or cross_attention if only_cross_attention is True)
        self.attn1 = FlaxAttention(
            self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
        )
        # cross attention
        self.attn2 = FlaxAttention(
            self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
        )
        self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
        self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
@@ -169,6 +304,8 @@ class FlaxTransformer2DModel(nn.Module):
        only_cross_attention (`bool`, defaults to `False`): tbd
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
    """
    in_channels: int
    n_heads: int
@@ -178,6 +315,7 @@ class FlaxTransformer2DModel(nn.Module):
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    use_memory_efficient_attention: bool = False

    def setup(self):
        self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
@@ -202,6 +340,7 @@ class FlaxTransformer2DModel(nn.Module):
                dropout=self.dropout,
                only_cross_attention=self.only_cross_attention,
                dtype=self.dtype,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
            )
            for _ in range(self.depth)
        ]
......
@@ -37,6 +37,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
            Number of attention heads of each spatial transformer block
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add downsampling layer before each final output
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
@@ -48,6 +50,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
    add_downsample: bool = True
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
@@ -72,6 +75,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
                depth=1,
                use_linear_projection=self.use_linear_projection,
                only_cross_attention=self.only_cross_attention,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                dtype=self.dtype,
            )
            attentions.append(attn_block)
@@ -172,6 +176,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
            Number of attention heads of each spatial transformer block
        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add upsampling layer before each final output
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
@@ -184,6 +190,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
    add_upsample: bool = True
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
@@ -209,6 +216,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
                depth=1,
                use_linear_projection=self.use_linear_projection,
                only_cross_attention=self.only_cross_attention,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                dtype=self.dtype,
            )
            attentions.append(attn_block)
@@ -311,6 +319,8 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
            Number of attention block layers
        attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
            Number of attention heads of each spatial transformer block
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
@@ -319,6 +329,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
    num_layers: int = 1
    attn_num_head_channels: int = 1
    use_linear_projection: bool = False
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
@@ -341,6 +352,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
                d_head=self.in_channels // self.attn_num_head_channels,
                depth=1,
                use_linear_projection=self.use_linear_projection,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                dtype=self.dtype,
            )
            attentions.append(attn_block)
......
@@ -88,6 +88,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
    """
@@ -111,6 +113,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
    dtype: jnp.dtype = jnp.float32
    flip_sin_to_cos: bool = True
    freq_shift: int = 0
    use_memory_efficient_attention: bool = False

    def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
        # init input tensors
@@ -169,6 +172,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
                    add_downsample=not is_final_block,
                    use_linear_projection=self.use_linear_projection,
                    only_cross_attention=only_cross_attention[i],
                    use_memory_efficient_attention=self.use_memory_efficient_attention,
                    dtype=self.dtype,
                )
            else:
@@ -190,6 +194,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
            dropout=self.dropout,
            attn_num_head_channels=attention_head_dim[-1],
            use_linear_projection=self.use_linear_projection,
            use_memory_efficient_attention=self.use_memory_efficient_attention,
            dtype=self.dtype,
        )
@@ -217,6 +222,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
                    dropout=self.dropout,
                    use_linear_projection=self.use_linear_projection,
                    only_cross_attention=only_cross_attention[i],
                    use_memory_efficient_attention=self.use_memory_efficient_attention,
                    dtype=self.dtype,
                )
            else:
......
@@ -296,6 +296,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        from_pt = kwargs.pop("from_pt", False)
        use_memory_efficient_attention = kwargs.pop("use_memory_efficient_attention", False)
        dtype = kwargs.pop("dtype", None)

        # 1. Download the checkpoints and configs
@@ -451,7 +452,12 @@ class FlaxDiffusionPipeline(ConfigMixin):
                loaded_sub_model = cached_folder

            if issubclass(class_obj, FlaxModelMixin):
                loaded_sub_model, loaded_params = load_method(
                    loadable_folder,
                    from_pt=from_pt,
                    use_memory_efficient_attention=use_memory_efficient_attention,
                    dtype=dtype,
                )
                params[name] = loaded_params
            elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
                if from_pt:
......
@@ -215,3 +215,47 @@ class FlaxPipelineTests(unittest.TestCase):
        if jax.device_count() == 8:
            assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.045043945)) < 1e-3
            assert np.abs((np.abs(images, dtype=np.float32).sum() - 2347693.5)) < 5e-1
    def test_jax_memory_efficient_attention(self):
        prompt = (
            "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
            " field, close up, split lighting, cinematic"
        )

        num_samples = jax.device_count()
        prompt = num_samples * [prompt]
        prng_seed = jax.random.split(jax.random.PRNGKey(0), num_samples)

        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            revision="bf16",
            dtype=jnp.bfloat16,
            safety_checker=None,
        )

        params = replicate(params)
        prompt_ids = pipeline.prepare_inputs(prompt)
        prompt_ids = shard(prompt_ids)
        images = pipeline(prompt_ids, params, prng_seed, jit=True).images

        assert images.shape == (num_samples, 1, 512, 512, 3)
        slice = images[2, 0, 256, 10:17, 1]

        # With memory efficient attention
        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            revision="bf16",
            dtype=jnp.bfloat16,
            safety_checker=None,
            use_memory_efficient_attention=True,
        )

        params = replicate(params)
        prompt_ids = pipeline.prepare_inputs(prompt)
        prompt_ids = shard(prompt_ids)
        images_eff = pipeline(prompt_ids, params, prng_seed, jit=True).images

        assert images_eff.shape == (num_samples, 1, 512, 512, 3)
        slice_eff = images_eff[2, 0, 256, 10:17, 1]

        # I checked the results visually and they are very similar. However, I saw that the max diff is `1` and the `sum`
        # over the 8 images is exactly `256`, which is very suspicious. Testing a random slice for now.
        assert abs(slice_eff - slice).max() < 1e-2