Commit 6eb5b12c authored by zhuwenwen's avatar zhuwenwen
Browse files

update qwen3_vl conv layout (ncdhw to ndhwc)

parent 7f417161
......@@ -27,6 +27,7 @@ from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from typing import Any, Callable, Optional, Union
import os
import numpy as np
import torch
import torch.nn as nn
......@@ -116,6 +117,8 @@ class Qwen3_VisionPatchEmbed(nn.Module):
L, C = x.shape
x = x.view(L, -1, self.temporal_patch_size, self.patch_size,
self.patch_size)
if os.environ.get('PYTORCH_MIOPEN_SUGGEST_NDHWC') == '1':
x = x.to(memory_format=torch.channels_last_3d)
x = self.proj(x).view(L, self.hidden_size)
return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment