import jax
import jax.numpy as jnp
from jax import random, grad, jit, vmap
import time

# Check whether JAX is using a GPU
print("JAX version:", jax.__version__)
print("Available devices:", jax.devices())
print("Device in use:", jax.devices()[0])
print("Device kind:", jax.devices()[0].device_kind)

# 1. Basic GPU matrix operations
def basic_gpu_operations():
    print("\n=== Basic GPU matrix operations ===")
    
    # Create large matrices (these are allocated on the GPU automatically)
    key = random.PRNGKey(42)
    size = 5000  # large matrices
    
    # Generate independent random matrices (split the key so A and B differ)
    key_a, key_b = random.split(key)
    A = random.normal(key_a, (size, size))
    B = random.normal(key_b, (size, size))
    
    # GPU matrix multiplication
    start = time.time()
    C = jnp.dot(A, B)  # runs on the GPU automatically
    C.block_until_ready()  # JAX dispatches asynchronously; sync before timing
    compute_time = time.time() - start
    
    print(f"Matrix size: {size}x{size}")
    print(f"GPU matmul time: {compute_time:.4f}s")
    
    # Inspect the result
    print(f"Shape of result C: {C.shape}")
    print(f"Mean of C: {jnp.mean(C):.6f}")
    print(f"Std of C: {jnp.std(C):.6f}")
    
    return C
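
# Sketch: why block_until_ready() matters for timing. JAX dispatches work
# asynchronously, so without a sync the timer can stop before the GPU has
# finished. Illustrative helper, not called by main().
def async_timing_sketch():
    A = jnp.ones((2000, 2000))
    start = time.time()
    _ = jnp.dot(A, A)                     # returns almost immediately
    dispatch_time = time.time() - start
    start = time.time()
    jnp.dot(A, A).block_until_ready()     # waits for the actual result
    synced_time = time.time() - start
    print(f"Dispatch only: {dispatch_time:.4f}s, with sync: {synced_time:.4f}s")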

# 2. Speeding things up with JIT compilation
def jit_acceleration():
    print("\n=== JIT compilation example ===")
    
    # Define a moderately complex function
    def complex_function(x):
        return jnp.sum(jnp.sin(x) * jnp.exp(-x**2) + jnp.log1p(jnp.abs(x)))
    
    # Create input data
    key = random.PRNGKey(0)
    x = random.normal(key, (10000, 1000))
    
    # Without JIT
    start = time.time()
    result_no_jit = complex_function(x).block_until_ready()
    time_no_jit = time.time() - start
    
    # With JIT compilation
    jitted_function = jit(complex_function)
    
    # The first call pays the compilation cost
    start = time.time()
    result_jit = jitted_function(x).block_until_ready()
    time_first_call = time.time() - start
    
    # Subsequent calls reuse the compiled code
    start = time.time()
    result_jit = jitted_function(x).block_until_ready()
    time_second_call = time.time() - start
    
    print(f"Without JIT: {time_no_jit:.4f}s")
    print(f"First JIT call (incl. compilation): {time_first_call:.4f}s")
    print(f"Subsequent JIT calls: {time_second_call:.4f}s")
    print(f"Speedup: {time_no_jit/time_second_call:.2f}x")
    
    return result_jit
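
# Sketch: jit compiles once per input shape/dtype and caches the result.
# Python-level arguments can be baked in with static_argnums; each distinct
# static value triggers its own compilation. Illustrative, not used by main().
def _power_impl(x, n):
    return x ** n

power_sketch = jit(_power_impl, static_argnums=1)
# power_sketch(jnp.arange(3.0), 2) compiles for n=2; n=3 would recompile.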

# 3. Automatic differentiation on the GPU
def autodiff_gpu():
    print("\n=== Autodiff on the GPU ===")
    
    # Define the loss function
    def loss(params, x, y):
        w, b = params
        predictions = jnp.dot(x, w) + b
        return jnp.mean((predictions - y)**2)
    
    # Generate synthetic data (use a fresh key for each random draw)
    key = random.PRNGKey(1)
    key1, key2, key3, key4, key5 = random.split(key, 5)
    
    n_samples = 10000
    n_features = 100
    
    X = random.normal(key1, (n_samples, n_features))
    true_w = random.normal(key2, (n_features, 1))
    true_b = random.normal(key3, (1,))
    Y = jnp.dot(X, true_w) + true_b + 0.1 * random.normal(key4, (n_samples, 1))
    
    # Initialize parameters
    params = (random.normal(key5, (n_features, 1)), jnp.zeros((1,)))
    
    # JIT-compiled gradient function
    grad_loss = jit(grad(loss, argnums=0))
    
    # Compute gradients on the GPU
    start = time.time()
    gradients = grad_loss(params, X, Y)
    jax.block_until_ready(gradients)  # sync the whole pytree before timing
    grad_time = time.time() - start
    
    print(f"Data shapes: X={X.shape}, Y={Y.shape}")
    print(f"GPU gradient time: {grad_time:.4f}s")
    print(f"Weight gradient shape: {gradients[0].shape}")
    print(f"Bias gradient shape: {gradients[1].shape}")
    
    return gradients
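
# Sketch: putting the gradients to use. jax.value_and_grad returns the loss
# and its gradients in one pass; this helper is illustrative, not called by
# main(), and assumes params is a tuple of arrays as in autodiff_gpu().
def sgd_step_sketch(loss_fn, params, x, y, lr=0.1):
    loss_val, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # grads mirrors the tuple structure of params, so zip them together
    new_params = tuple(p - lr * g for p, g in zip(params, grads))
    return new_params, loss_val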

# 4. Batched processing
def batch_processing():
    print("\n=== Batched processing on the GPU ===")
    
    # A function that processes a single sample
    def process_sample(x):
        # Some arbitrary transformation
        return jnp.sum(jnp.tanh(x) * jnp.cos(x))
    
    # Vectorize it with vmap
    batched_process = vmap(process_sample)
    
    # Create a batch of data
    key = random.PRNGKey(3)
    batch_size = 10000
    sample_size = 1000
    
    batch_data = random.normal(key, (batch_size, sample_size))
    
    # Process one sample at a time (slow: one dispatch per sample)
    def sequential_process(data):
        results = []
        for i in range(data.shape[0]):
            results.append(process_sample(data[i]))
        return jnp.array(results)
    
    # Time the sequential version
    start = time.time()
    sequential_results = sequential_process(batch_data).block_until_ready()
    sequential_time = time.time() - start
    
    # Time the vmapped version (runs in parallel on the GPU)
    start = time.time()
    batched_results = batched_process(batch_data).block_until_ready()
    batched_time = time.time() - start
    
    print(f"Batch size: {batch_size}, sample size: {sample_size}")
    print(f"Sequential time: {sequential_time:.4f}s")
    print(f"Batched time: {batched_time:.4f}s")
    print(f"Speedup: {sequential_time/batched_time:.2f}x")
    
    # Check that the two results agree
    error = jnp.max(jnp.abs(sequential_results - batched_results))
    print(f"Max absolute error: {error:.10f}")
    
    return batched_results
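
# Sketch: vmap's in_axes controls which arguments are batched. Here the
# weight matrix is shared (None) while x is mapped over axis 0. Illustrative,
# not called by main().
def vmap_axes_sketch():
    def affine(w, x):
        return jnp.dot(w, x)              # (3, 4) @ (4,) -> (3,)
    w = jnp.ones((3, 4))
    xs = jnp.ones((8, 4))                 # a batch of 8 vectors
    ys = vmap(affine, in_axes=(None, 0))(w, xs)
    print("Batched output shape:", ys.shape)  # (8, 3)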

# 5. Advanced example: simple neural-network layers
def neural_network_example():
    print("\n=== Neural-network layers on the GPU ===")
    
    # A simple dense layer
    def dense_layer(params, x):
        w, b = params
        return jnp.dot(x, w) + b
    
    # ReLU activation
    def relu(x):
        return jnp.maximum(0, x)
    
    # Two-layer network
    def two_layer_net(params, x):
        w1, b1, w2, b2 = params
        hidden = relu(jnp.dot(x, w1) + b1)
        output = jnp.dot(hidden, w2) + b2
        return output
    
    # Initialize parameters
    key = random.PRNGKey(4)
    key1, key2, key3 = random.split(key, 3)
    
    input_size = 784   # MNIST image size
    hidden_size = 128
    output_size = 10   # 10 classes
    
    # He initialization for the weights
    w1 = random.normal(key1, (input_size, hidden_size)) * jnp.sqrt(2.0 / input_size)
    b1 = jnp.zeros(hidden_size)
    w2 = random.normal(key2, (hidden_size, output_size)) * jnp.sqrt(2.0 / hidden_size)
    b2 = jnp.zeros(output_size)
    
    params = (w1, b1, w2, b2)
    
    # Create a batch of inputs
    batch_size = 1024
    x_batch = random.normal(key3, (batch_size, input_size))
    
    # JIT-compile the network
    jitted_net = jit(two_layer_net)
    
    # Forward pass
    start = time.time()
    outputs = jitted_net(params, x_batch).block_until_ready()
    inference_time = time.time() - start
    
    print(f"Architecture: {input_size} -> {hidden_size} -> {output_size}")
    print(f"Batch size: {batch_size}")
    print(f"GPU inference time: {inference_time:.6f}s")
    print(f"Output shape: {outputs.shape}")
    print(f"Mean output per class: {jnp.mean(outputs, axis=0)}")
    
    return outputs
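
# Sketch: a single gradient-descent training step for a net like the one
# above. The softmax cross-entropy loss is an illustrative choice, not part
# of the demo; close over a concrete net and wrap this in jit for speed.
def train_step_sketch(net, params, x, y_onehot, lr=1e-2):
    def loss_fn(p):
        logits = net(p, x)
        logp = jax.nn.log_softmax(logits)
        return -jnp.mean(jnp.sum(y_onehot * logp, axis=1))
    grads = grad(loss_fn)(params)
    return tuple(p - lr * g for p, g in zip(params, grads))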

def main():
    print("="*50)
    print("JAX GPU Demo - starting")
    print("="*50)
    
    # Run all the examples
    try:
        # Example 1: basic GPU operations
        result1 = basic_gpu_operations()
        
        # Example 2: JIT acceleration
        result2 = jit_acceleration()
        
        # Example 3: automatic differentiation
        result3 = autodiff_gpu()
        
        # Example 4: batched processing
        result4 = batch_processing()
        
        # Example 5: neural network
        result5 = neural_network_example()
        
        print("\n" + "="*50)
        print("All examples completed!")
        print("="*50)
        
        # Verify that the GPU is actually doing the work
        print("\nChecking GPU usage:")
        with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
            # Run an operation so it shows up in the trace
            test_matrix = jnp.ones((1000, 1000))
            test_result = jnp.dot(test_matrix, test_matrix.T).block_until_ready()
            print(f"Trace test finished, result shape: {test_result.shape}")
            
    except Exception as e:
        print(f"Error during execution: {e}")
        print("\nNote: make sure the GPU-enabled JAX build is installed correctly")
        print("Install command: pip install --upgrade 'jax[cuda12_pip]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html")

if __name__ == "__main__":
    main()