Commit 6eb5b12c authored by zhuwenwen's avatar zhuwenwen
Browse files

update qwen3_vl conv layout (ncdhw to ndhwc)

parent 7f417161
...@@ -27,6 +27,7 @@ from collections.abc import Iterable, Mapping, Sequence ...@@ -27,6 +27,7 @@ from collections.abc import Iterable, Mapping, Sequence
from functools import partial from functools import partial
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union
import os
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -116,6 +117,8 @@ class Qwen3_VisionPatchEmbed(nn.Module): ...@@ -116,6 +117,8 @@ class Qwen3_VisionPatchEmbed(nn.Module):
L, C = x.shape L, C = x.shape
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, x = x.view(L, -1, self.temporal_patch_size, self.patch_size,
self.patch_size) self.patch_size)
if os.environ.get('PYTORCH_MIOPEN_SUGGEST_NDHWC') == '1':
x = x.to(memory_format=torch.channels_last_3d)
x = self.proj(x).view(L, self.hidden_size) x = self.proj(x).view(L, self.hidden_size)
return x return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment